You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. import re
  17. import copy
  18. from mindinsight.datavisual.common.log import logger
  19. from .node import Node
  20. from .node import NodeTypeEnum
  21. from .graph import Graph
  22. from .graph import EdgeTypeEnum
  23. from .graph import DataTypeEnum
  class MSGraph(Graph):
      """The object describes the MindSpore graph, which is defined in the anf_ir proto file."""

      def build_graph(self, graph_proto):
          """
          Build the graph from a graph proto.

          Leaf nodes are parsed first; polymeric nodes, name scope nodes and the
          polymeric input/output are then derived by helper methods that are not
          defined in this class (presumably inherited from `Graph` -- confirm).

          Args:
              graph_proto (anf_ir_pb2.GraphProto): Refer to `anf_ir_pb2.GraphProto`.
          """
          logger.info("Start to build graph.")

          self._build_leaf_nodes(graph_proto)
          self._build_polymeric_nodes()
          self._build_name_scope_nodes()
          self._update_polymeric_input_output()

          logger.info("Build graph end, normal node count: %s, polymeric node "
                      "count: %s.", len(self._normal_nodes), len(self._polymeric_nodes))

      def _build_leaf_nodes(self, graph_proto):
          """
          Build leaf nodes from the graph proto.

          Leaf nodes are operation nodes, parameter nodes and const nodes.
          Entries with an empty name (or an empty key for consts) are skipped
          with a warning.

          Args:
              graph_proto (anf_ir_pb2.model_proto.graph): Refer to anf_ir_pb2.model_proto.graph.
          """
          logger.info("Start to build leaf nodes.")
          leaf_node_id_map_name = {}
          # Parameters and consts share one lookup table, keyed by the name that
          # operation nodes use to reference them as inputs.
          const_nodes_map = {}

          for node_def in graph_proto.node:
              if not node_def.name:
                  logger.warning("Finding a node with an empty name will not save it.")
                  continue
              node = self._parse_graph_proto_node(node_def)
              leaf_node_id_map_name[node.node_id] = node.name

          for parameter in graph_proto.parameters:
              if not parameter.name:
                  logger.warning("Finding a parameter with an empty name will not save it.")
                  continue
              node = self._parse_graph_proto_parameter(parameter)
              const_nodes_map[node.name] = node

          for i, const in enumerate(graph_proto.const_vals):
              if not const.key:
                  logger.warning("Finding a const with an empty key will not save it.")
                  continue
              # The index is taken before filtering, so skipped consts leave gaps in the ids.
              node_id = 'const_{}'.format(i)
              node = self._parse_graph_proto_const(const, node_id)
              const_nodes_map[const.key] = node

          self._calc_input(leaf_node_id_map_name, graph_proto, const_nodes_map)
          self._calc_output()

          logger.info("Build leaf nodes end, normal nodes count: %s, group count: %s, "
                      "leaf nodes count: %s.", len(self._normal_nodes), len(self._node_groups),
                      len(self._leaf_nodes))

      def _calc_input(self, leaf_node_id_map_name, graph_proto, const_nodes_map):
          """
          Calc input for every leaf node.

          For each input that refers to a parameter/const, a copy of that node is
          registered in `self._normal_nodes` and `self._leaf_nodes` under the name
          '<consumer name scope>/<input name>', unless a node with that name is
          already registered as a normal node.

          Args:
              leaf_node_id_map_name (dict[str, str]): Format is {'node_id': 'node_name'}.
              graph_proto (anf_ir_pb2.model_proto.graph): See anf_ir_pb2.model_proto.graph.
              const_nodes_map (dict[str, Node]): Format is {'node name': const or parameter Node}.
          """
          logger.debug("Start to calc input.")
          for node_def in graph_proto.node:
              if not node_def.name:
                  logger.debug("The node name is empty, ignore it.")
                  continue
              node_name = leaf_node_id_map_name[node_def.name]
              node = self._leaf_nodes[node_name]

              for input_def in node_def.input:
                  if not input_def.name:
                      logger.warning("The input node name is empty, ignore it. node name: %s.", node_name)
                      continue

                  # NOTE(review): `input_def.type` is compared with a string here. If the
                  # proto declares `type` as an enum, the Python protobuf API returns an
                  # int and this branch never fires -- confirm against the proto file.
                  edge_type = EdgeTypeEnum.DATA.value
                  if input_def.type == "CONTROL_EDGE":
                      edge_type = EdgeTypeEnum.CONTROL.value

                  const_template = const_nodes_map.get(input_def.name)
                  if const_template:
                      # NOTE(review): with an empty name scope this yields '/<input name>'
                      # -- confirm that the leading slash is intended.
                      src_name = '{}/{}'.format(node.name_scope, input_def.name)
                      if not self._normal_nodes.get(src_name):
                          # Copy only when a new scoped node is actually registered;
                          # a copy made for an already registered name would be discarded.
                          const_node = copy.deepcopy(const_template)
                          const_node.name = src_name
                          const_node.name_scope = node.name_scope
                          self._normal_nodes[src_name] = const_node
                          self._leaf_nodes[src_name] = const_node
                      src_node = self._leaf_nodes.get(src_name)
                  else:
                      src_name = leaf_node_id_map_name.get(input_def.name)
                      if not src_name:
                          logger.warning("The input_def name '%s' in node '%s' is invalid, "
                                         "will be ignore.", input_def.name, node_name)
                          continue

                      src_node = self._leaf_nodes.get(src_name)
                      if src_node is None:
                          logger.warning("The input '%s' in node '%s' is not in "
                                         "leaf nodes.", src_name, node_name)
                          continue

                  input_item = {
                      src_name: {
                          "shape": src_node.shape,
                          "edge_type": edge_type,
                          "scope": NodeTypeEnum.NAME_SCOPE.value
                      }
                  }
                  node.update_input(input_item)

              # Store the node back into whichever container holds it.
              if self._normal_nodes.get(node_name):
                  self._normal_nodes[node_name] = node
              else:
                  group_name = self._create_group_name(node.name_scope, node.node_type, node.name)
                  self._node_groups[group_name][node.name] = node

      def _calc_output(self):
          """
          Calc output of every node.

          Each non-const leaf node is added as an output of each of its non-const
          input nodes, reusing the attributes stored for that input.
          """
          logger.debug("Start to calc output.")

          for name, node in self._leaf_nodes.items():
              if node.node_type == NodeTypeEnum.CONST.value:
                  continue
              for src_name, input_attr in node.inputs.items():
                  src_node = self._leaf_nodes[src_name]
                  if src_node.node_type == NodeTypeEnum.CONST.value:
                      continue

                  if self._normal_nodes.get(src_name):
                      self._normal_nodes[src_name].update_output({name: input_attr})
                  else:
                      # The source node is not a normal node, so look it up in its group.
                      group_name = self._create_group_name(src_node.name_scope,
                                                           src_node.node_type, src_node.name)
                      self._node_groups[group_name][src_name].update_output({name: input_attr})

      def _parse_graph_proto_node(self, node_def):
          """
          Parse `anf_ir_pb2.model_proto.graph.node_def`, and create a node.

          The node is always registered in `self._leaf_nodes`. It is additionally
          stored in `self._node_groups` when `_create_group_name` returns a group
          name, and in `self._normal_nodes` otherwise.

          Args:
              node_def (anf_ir_pb2.model_proto.graph.node_def): Refer to anf_ir_pb2.model_proto.graph.node_def.

          Returns:
              Node, a `Node` object.
          """
          # The node name is '<scope>/<op_type><proto name>', or '<op_type><proto name>'
          # when the scope is empty.
          if node_def.scope:
              node_name = '/'.join([node_def.scope, node_def.op_type]) + node_def.name
          else:
              node_name = node_def.op_type + node_def.name

          node = Node(name=node_name, node_id=node_def.name)
          node.node_type = node_def.op_type
          logger.debug("Foreach graph proto nodes, node id: %s, node name: %s, node def name: %s, "
                       "input count: %s", node.node_id, node.name, node_def.name, len(node_def.input))

          for attr in node_def.attribute:
              node.update_attr({attr.name: str(attr.value)})

          node.output_i = node_def.output_i
          node.name_scope = node_def.scope
          node.shape = self._parse_type_proto(node_def.output_type)

          self._leaf_nodes[node.name] = node
          group_name = self._create_group_name(node.name_scope, node.node_type, node.name)
          if group_name is not None:
              self._node_groups.setdefault(group_name, {})[node.name] = node
          else:
              self._normal_nodes[node.name] = node

          return node

      def _parse_graph_proto_parameter(self, parameter):
          """
          Parse anf_ir_pb2.model_proto.graph.parameter, and create a parameter node.

          The node is only returned; it is not registered in any container here.

          Args:
              parameter (anf_ir_pb2.model_proto.graph.parameter): Refer to anf_ir_pb2.model_proto.graph.parameter.

          Returns:
              Node, a `Node` object.
          """
          node = Node(name=parameter.name, node_id=parameter.name)
          node.node_type = NodeTypeEnum.PARAMETER.value
          node.shape = self._parse_type_proto(parameter.type)
          logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
                       "node def name: %s", node.node_id, node.name, parameter.name)
          return node

      def _parse_graph_proto_const(self, const, const_node_id):
          """
          Parse anf_ir_pb2.model_proto.graph.const, and create a const node.

          The shape is only set when the const value is a tensor.

          Args:
              const (anf_ir_pb2.model_proto.graph.const): Refer to anf_ir_pb2.model_proto.graph.const
              const_node_id (str): The id of the new const node, it should be unique in graph.

          Returns:
              Node, a `Node` object.
          """
          node = Node(name=const.key, node_id=const_node_id)
          node.node_type = NodeTypeEnum.CONST.value
          node.update_attr({const.key: str(const.value)})
          if const.value.dtype == DataTypeEnum.DT_TENSOR.value:
              node.shape = list(const.value.tensor_val.dims)
          return node

      def _parse_type_proto(self, type_proto):
          """
          Parse proto's `message TypeProto` to get shape information.

          A tensor type contributes the size of each dimension; a sequence type
          contributes one nested list per element type.

          Args:
              type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.

          Returns:
              list, a list of shape.
          """
          shapes = []
          if type_proto.HasField('tensor_type'):
              shapes.extend(dim.size for dim in type_proto.tensor_type.shape.dim)
          if type_proto.HasField('sequence_type'):
              shapes.extend(self._parse_type_proto(elem_type)
                            for elem_type in type_proto.sequence_type.elem_types)
          return shapes

      def _create_group_name(self, name_scope, node_type, node_name):
          """
          Create group name by node name, name scope, node type.

          Only nodes that conform to the rules are aggregated: nodes of type
          'Reshape' or 'Variable', and 'FrameworkOp' nodes whose name contains
          '/Cast-op' followed by at least one digit.

          Args:
              name_scope (str): The node name scope.
              node_type (str): The node type.
              node_name (str): The node name.

          Returns:
              Optional[str], if match the rules will return a group name, else return None.
          """
          if node_type in ('Reshape', 'Variable'):
              group_suffix = node_type
          elif node_type == 'FrameworkOp' and re.search(r'.*?/Cast-op\d+', node_name):
              group_suffix = 'Cast-op'
          else:
              return None

          # The group name is prefixed with the name scope only when the scope is not empty.
          return name_scope + '/' + group_suffix if name_scope else group_suffix

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中,可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中,通过MindInsight可视化页面进行查看及分析。