You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. import time
  17. from mindinsight.datavisual.common.log import logger
  18. from mindinsight.datavisual.proto_files.mindinsight_anf_ir_pb2 import DataType
  19. from mindinsight.datavisual.common.enums import PluginNameEnum
  20. from .node import Node
  21. from .node import NodeTypeEnum
  22. from .graph import Graph
  23. from .graph import EdgeTypeEnum
  24. class MSGraph(Graph):
  25. """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""
  26. def build_graph(self, proto_data):
  27. """
  28. Build graph by graph proto which refer to `anf_ir_pb2.GraphProto`.
  29. Args:
  30. proto_data (anf_ir_pb2.GraphProto): Refer to `anf_ir_pb2.GraphProto`.
  31. """
  32. logger.info("Start to build graph, graph name: %s.", proto_data.name)
  33. start_time = time.time()
  34. super(MSGraph, self).build_graph(proto_data)
  35. precision = 6
  36. time_consuming = round(time.time()-start_time, precision)
  37. logger.info("Build graph end, all node count: %s, const count: %s, parameter count: %s, time-consuming: %s s.",
  38. self.normal_node_count, len(self._const_node_temp_cache),
  39. len(self._parameter_node_temp_cache), time_consuming)
  40. def _parse_data(self, proto_data):
  41. """
  42. The proto data is parsed and all nodes are stored in the specified structure.
  43. Args:
  44. proto_data (anf_ir_pb2.GraphProto): Refer to anf_ir_pb2.GraphProto object.
  45. """
  46. logger.info("Start to parse graph proto data.")
  47. self._parse_op_nodes(proto_data.node)
  48. self._parse_parameters(proto_data.parameters)
  49. self._parse_consts(proto_data.const_vals)
  50. self._update_input_after_create_node()
  51. self._update_output_after_create_node()
  52. logger.info("Parse proto data end, normal node count(only contain op node, "
  53. "parameter, const): %s.", self.normal_node_count)
  54. def _parse_op_nodes(self, node_protos):
  55. """
  56. Parse `anf_ir_pb2.NodeProto` object, and create a normal node.
  57. Args:
  58. node_protos (list[anf_ir_pb2.NodeProto]): Refer to anf_ir_pb2.NodeProto.
  59. """
  60. logger.debug("Start to parse op nodes from proto.")
  61. for node_proto in node_protos:
  62. if not node_proto.name:
  63. logger.warning("Finding a node with an empty name will not save it.")
  64. continue
  65. if not node_proto.full_name or any(
  66. node_proto.full_name.lower().endswith(f'[:{plugin.value.lower()}]') for plugin in PluginNameEnum):
  67. node_name = Node.create_node_name(scope=node_proto.scope,
  68. base_name=f'{node_proto.op_type}{node_proto.name}')
  69. else:
  70. node_name = node_proto.full_name
  71. node = Node(name=node_name, node_id=node_proto.name)
  72. node.type = node_proto.op_type
  73. logger.debug("Foreach graph proto nodes, node id: %s, node name: %s, node def name: %s, "
  74. "input count: %s", node.node_id, node.name, node_proto.name, len(node_proto.input))
  75. self._parse_attributes(node_proto.attribute, node)
  76. self._parse_inputs(node_proto.input, node)
  77. node.output_i = node_proto.output_i
  78. node.scope = node_proto.scope
  79. node.output_shape = self._get_shape_by_parse_type_proto(node_proto.output_type)
  80. node.output_data_type = self._get_data_type_by_parse_type_proto(node_proto.output_type)
  81. self._cache_node(node)
  82. def _parse_parameters(self, parameter_protos):
  83. """
  84. Parse `anf_ir_pb2.ParameterProto` object, and create a parameter node.
  85. Args:
  86. parameter_protos (list[anf_ir_pb2.ParameterProto]): Refer to anf_ir_pb2.ParameterProto.
  87. """
  88. logger.debug("Start to parse parameters from proto.")
  89. for parameter in parameter_protos:
  90. if not parameter.name:
  91. logger.warning("Finding a parameter with an empty name will not save it.")
  92. continue
  93. node = Node(name=parameter.name, node_id=parameter.name)
  94. node.type = NodeTypeEnum.PARAMETER.value
  95. node.output_shape = self._get_shape_by_parse_type_proto(parameter.type)
  96. attr = dict(
  97. type=self._get_data_type_by_parse_type_proto(parameter.type),
  98. shape=str(self._get_shape_by_parse_type_proto(parameter.type))
  99. )
  100. node.add_attr(attr)
  101. self._cache_node(node)
  102. logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
  103. "node def name: %s", node.node_id, node.name, parameter.name)
  104. def _parse_consts(self, consts):
  105. """
  106. Parse `anf_ir_pb2.NameValueProto` object, and create a const node.
  107. Args:
  108. consts (list[anf_ir_pb2.NameValueProto]): Refer to `anf_ir_pb2.NameValueProto` object.
  109. """
  110. logger.debug("Start to parse consts from proto.")
  111. for const in consts:
  112. if not const.key:
  113. logger.warning("Finding a const with an empty key will not save it.")
  114. continue
  115. node = Node(name=const.key, node_id=const.key)
  116. node.type = NodeTypeEnum.CONST.value
  117. node.add_attr({const.key: str(const.value)})
  118. if const.value.dtype == DataType.DT_TENSOR:
  119. shape = []
  120. for dim in const.value.tensor_val.dims:
  121. shape.append(dim)
  122. node.output_shape = shape
  123. self._cache_node(node)
  124. def _get_shape_by_parse_type_proto(self, type_proto):
  125. """
  126. Parse proto's `message TypeProto` to get shape information.
  127. Args:
  128. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  129. Returns:
  130. list, a list of shape.
  131. """
  132. shapes = []
  133. if type_proto.HasField('tensor_type'):
  134. tensor_type = type_proto.tensor_type
  135. tensor_shape_proto = tensor_type.shape
  136. for dim in tensor_shape_proto.dim:
  137. shapes.append(dim.size)
  138. if type_proto.HasField('sequence_type'):
  139. for elem_type in type_proto.sequence_type.elem_types:
  140. shapes.append(self._get_shape_by_parse_type_proto(elem_type))
  141. return shapes
  142. def _get_data_type_by_parse_type_proto(self, type_proto):
  143. """
  144. Get data type by parse type proto object.
  145. The name of the DataType, refer to `anf_ir_pb2.DataType` object.
  146. If data type is tensor or tuple, the data name we return is `data_type[element_type, element_type]`.
  147. Args:
  148. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  149. Returns:
  150. str, the data type.
  151. """
  152. data_type_name = self._get_data_type_name_by_value(type_proto, type_proto.data_type, field_name='data_type')
  153. if type_proto.data_type == DataType.DT_TENSOR:
  154. tensor_type_proto = type_proto.tensor_type
  155. value = type_proto.tensor_type.elem_type
  156. elem_type_name = self._get_data_type_name_by_value(tensor_type_proto, value, field_name='elem_type')
  157. return f'{data_type_name}[{elem_type_name}]'
  158. if type_proto.data_type == DataType.DT_TUPLE:
  159. data_types = []
  160. for elem_type in type_proto.sequence_type.elem_types:
  161. data_types.append(self._get_data_type_by_parse_type_proto(elem_type))
  162. return f'{data_type_name}{str(data_types)}'
  163. return data_type_name
  164. def _parse_inputs(self, input_protos, node):
  165. """
  166. Parse `anf_ir_pb2.InputProto` object.
  167. Args:
  168. input_protos (list[anf_ir_pb2.InputProto]): Refer to `anf_ir_pb2.InputProto` object.
  169. node (Node): Refer to `Node` object, it is used to log message and update input.
  170. """
  171. for input_proto in input_protos:
  172. if not input_proto.name:
  173. logger.warning("The name in input proto of node(%s) is empty, will ignore.", node.name)
  174. continue
  175. edge_type = EdgeTypeEnum.DATA.value if not input_proto.type else EdgeTypeEnum.CONTROL.value
  176. # Notice:
  177. # 1. The name in the input proto is the node id of the Node object.
  178. # 2. In the current step, the shape of source node cannot be obtained,
  179. # so it is set to empty list by default, and the next step will update it.
  180. # 3. Same with scope, set the default value first.
  181. input_attr = {
  182. "shape": [],
  183. "edge_type": edge_type,
  184. "independent_layout": False,
  185. 'data_type': ''
  186. }
  187. node.add_input(src_name=input_proto.name, input_attr=input_attr)
  188. def _parse_attributes(self, attributes, node):
  189. """
  190. Parse `anf_ir_pb2.AttributeProto` object., and Filters large attribute values.
  191. Args:
  192. attributes (list[anf_ir_pb2.AttributeProto]): Refer to `anf_ir_pb2.AttributeProto` object.
  193. node (Node): Refer to `Node` object, it is used to log message and update attr.
  194. """
  195. for attr in attributes:
  196. if attr.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  197. message = f"The attribute value of node({node.name}) " \
  198. f"is over {self.MAX_NODE_ATTRIBUTE_VALUE_BYTES} Bytes, will ignore."
  199. logger.info(message)
  200. continue
  201. node.add_attr({attr.name: str(attr.value)})
  202. def _update_input_after_create_node(self):
  203. """Update the input of node after create node."""
  204. for node in self._normal_node_map.values():
  205. for src_node_id, input_attr in dict(node.input).items():
  206. node.delete_input(src_node_id)
  207. if not self._is_node_exist(node_id=src_node_id):
  208. message = f"The input node could not be found by node id({src_node_id}) " \
  209. f"while updating the input of the node({node})"
  210. logger.warning(message)
  211. continue
  212. src_node = self._get_normal_node(node_id=src_node_id)
  213. input_attr['shape'] = src_node.output_shape
  214. input_attr['data_type'] = src_node.output_data_type
  215. node.add_input(src_name=src_node.name, input_attr=input_attr)
  216. def _update_output_after_create_node(self):
  217. """Update the output of node after create node."""
  218. # Constants and parameter should not exist for input and output.
  219. filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
  220. for node in self._normal_node_map.values():
  221. for src_name, input_attr in node.input.items():
  222. src_node = self._get_normal_node(node_name=src_name)
  223. if src_node.type in filtered_node:
  224. continue
  225. src_node.add_output(node.name, input_attr)
  226. @staticmethod
  227. def _get_data_type_name_by_value(data_type, value, field_name='data_type'):
  228. """Get the data type name by the enum value, data_type refer to `DataType` object."""
  229. return data_type.DESCRIPTOR.fields_by_name[field_name].enum_type.values_by_number[value].name