You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. from mindinsight.datavisual.common.log import logger
  17. from mindinsight.datavisual.proto_files.mindinsight_anf_ir_pb2 import DataType
  18. from mindinsight.datavisual.common.enums import PluginNameEnum
  19. from .node import Node
  20. from .node import NodeTypeEnum
  21. from .graph import Graph
  22. from .graph import EdgeTypeEnum
  23. class MSGraph(Graph):
  24. """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""
  25. def _parse_data(self, proto_data):
  26. """
  27. The proto data is parsed and all nodes are stored in the specified structure.
  28. Args:
  29. proto_data (anf_ir_pb2.GraphProto): Refer to anf_ir_pb2.GraphProto object.
  30. """
  31. logger.info("Start to parse graph proto data.")
  32. self._parse_op_nodes(proto_data.node)
  33. self._parse_parameters(proto_data.parameters)
  34. self._parse_consts(proto_data.const_vals)
  35. self._update_input_after_create_node()
  36. self._update_output_after_create_node()
  37. logger.info("Parse proto data end, normal node count(only contain op node, "
  38. "parameter, const): %s.", self.normal_node_count)
  39. def _parse_op_nodes(self, node_protos):
  40. """
  41. Parse `anf_ir_pb2.NodeProto` object, and create a normal node.
  42. Args:
  43. node_protos (list[anf_ir_pb2.NodeProto]): Refer to anf_ir_pb2.NodeProto.
  44. """
  45. logger.debug("Start to parse op nodes from proto.")
  46. for node_proto in node_protos:
  47. if not node_proto.name:
  48. logger.warning("Finding a node with an empty name will not save it.")
  49. continue
  50. if not node_proto.full_name or any(
  51. node_proto.full_name.lower().endswith(f'[:{plugin.value.lower()}]') for plugin in PluginNameEnum):
  52. node_name = Node.create_node_name(scope=node_proto.scope,
  53. base_name=f'{node_proto.op_type}{node_proto.name}')
  54. else:
  55. node_name = node_proto.full_name
  56. node = Node(name=node_name, node_id=node_proto.name)
  57. node.type = node_proto.op_type
  58. self._parse_attributes(node_proto.attribute, node)
  59. self._parse_inputs(node_proto.input, node)
  60. node.output_i = node_proto.output_i
  61. node.scope = node_proto.scope
  62. node.output_shape = self._get_shape_by_parse_type_proto(node_proto.output_type)
  63. node.output_data_type = self._get_data_type_by_parse_type_proto(node_proto.output_type)
  64. self._cache_node(node)
  65. def _parse_parameters(self, parameter_protos):
  66. """
  67. Parse `anf_ir_pb2.ParameterProto` object, and create a parameter node.
  68. Args:
  69. parameter_protos (list[anf_ir_pb2.ParameterProto]): Refer to anf_ir_pb2.ParameterProto.
  70. """
  71. logger.debug("Start to parse parameters from proto.")
  72. for parameter in parameter_protos:
  73. if not parameter.name:
  74. logger.warning("Finding a parameter with an empty name will not save it.")
  75. continue
  76. node = Node(name=parameter.name, node_id=parameter.name)
  77. node.type = NodeTypeEnum.PARAMETER.value
  78. node.output_data_type = self._get_data_type_by_parse_type_proto(parameter.type)
  79. node.output_shape = self._get_shape_by_parse_type_proto(parameter.type)
  80. attr = dict(
  81. type=self._get_data_type_by_parse_type_proto(parameter.type),
  82. shape=str(self._get_shape_by_parse_type_proto(parameter.type))
  83. )
  84. node.add_attr(attr)
  85. self._cache_node(node)
  86. logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
  87. "node def name: %s", node.node_id, node.name, parameter.name)
  88. def _parse_consts(self, consts):
  89. """
  90. Parse `anf_ir_pb2.NameValueProto` object, and create a const node.
  91. Args:
  92. consts (list[anf_ir_pb2.NameValueProto]): Refer to `anf_ir_pb2.NameValueProto` object.
  93. """
  94. logger.debug("Start to parse consts from proto.")
  95. for const in consts:
  96. if not const.key:
  97. logger.warning("Finding a const with an empty key will not save it.")
  98. continue
  99. node = Node(name=const.key, node_id=const.key)
  100. node.type = NodeTypeEnum.CONST.value
  101. node.add_attr({const.key: str(const.value)})
  102. if const.value.dtype == DataType.DT_TENSOR:
  103. shape = []
  104. for dim in const.value.tensor_val.dims:
  105. shape.append(dim)
  106. node.output_shape = shape
  107. self._cache_node(node)
  108. def _get_shape_by_parse_type_proto(self, type_proto):
  109. """
  110. Parse proto's `message TypeProto` to get shape information.
  111. Args:
  112. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  113. Returns:
  114. list, a list of shape.
  115. """
  116. shapes = []
  117. if type_proto.HasField('tensor_type'):
  118. tensor_type = type_proto.tensor_type
  119. tensor_shape_proto = tensor_type.shape
  120. for dim in tensor_shape_proto.dim:
  121. shapes.append(dim.size)
  122. if type_proto.HasField('sequence_type'):
  123. for elem_type in type_proto.sequence_type.elem_types:
  124. shapes.append(self._get_shape_by_parse_type_proto(elem_type))
  125. return shapes
  126. def _get_data_type_by_parse_type_proto(self, type_proto):
  127. """
  128. Get data type by parse type proto object.
  129. The name of the DataType, refer to `anf_ir_pb2.DataType` object.
  130. If data type is tensor or tuple, the data name we return is `data_type[element_type, element_type]`.
  131. Args:
  132. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  133. Returns:
  134. str, the data type.
  135. """
  136. data_type_name = self._get_data_type_name_by_value(type_proto, type_proto.data_type, field_name='data_type')
  137. if type_proto.data_type == DataType.DT_TENSOR:
  138. tensor_type_proto = type_proto.tensor_type
  139. value = type_proto.tensor_type.elem_type
  140. elem_type_name = self._get_data_type_name_by_value(tensor_type_proto, value, field_name='elem_type')
  141. return f'{data_type_name}[{elem_type_name}]'
  142. if type_proto.data_type == DataType.DT_TUPLE:
  143. data_types = []
  144. for elem_type in type_proto.sequence_type.elem_types:
  145. data_types.append(self._get_data_type_by_parse_type_proto(elem_type))
  146. return f'{data_type_name}{str(data_types)}'
  147. return data_type_name
  148. def _parse_inputs(self, input_protos, node):
  149. """
  150. Parse `anf_ir_pb2.InputProto` object.
  151. Args:
  152. input_protos (list[anf_ir_pb2.InputProto]): Refer to `anf_ir_pb2.InputProto` object.
  153. node (Node): Refer to `Node` object, it is used to log message and update input.
  154. """
  155. for input_proto in input_protos:
  156. if not input_proto.name:
  157. logger.warning("The name in input proto of node(%s) is empty, will ignore.", node.name)
  158. continue
  159. edge_type = EdgeTypeEnum.DATA.value if not input_proto.type else EdgeTypeEnum.CONTROL.value
  160. # Notice:
  161. # 1. The name in the input proto is the node id of the Node object.
  162. # 2. In the current step, the shape of source node cannot be obtained,
  163. # so it is set to empty list by default, and the next step will update it.
  164. # 3. Same with scope, set the default value first.
  165. input_attr = {
  166. "shape": [],
  167. "edge_type": edge_type,
  168. "independent_layout": False,
  169. 'data_type': ''
  170. }
  171. node.add_input(src_name=input_proto.name, input_attr=input_attr)
  172. def _parse_attributes(self, attributes, node):
  173. """
  174. Parse `anf_ir_pb2.AttributeProto` object., and Filters large attribute values.
  175. Args:
  176. attributes (list[anf_ir_pb2.AttributeProto]): Refer to `anf_ir_pb2.AttributeProto` object.
  177. node (Node): Refer to `Node` object, it is used to log message and update attr.
  178. """
  179. for attr in attributes:
  180. if attr.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  181. message = f"The attribute value of node({node.name}) " \
  182. f"is over {self.MAX_NODE_ATTRIBUTE_VALUE_BYTES} Bytes, will ignore."
  183. logger.warning(message)
  184. continue
  185. node.add_attr({attr.name: str(attr.value)})
  186. def _update_input_after_create_node(self):
  187. """Update the input of node after create node."""
  188. for node in self._normal_node_map.values():
  189. for src_node_id, input_attr in dict(node.input).items():
  190. node.delete_input(src_node_id)
  191. if not self._is_node_exist(node_id=src_node_id):
  192. message = f"The input node could not be found by node id({src_node_id}) " \
  193. f"while updating the input of the node({node})"
  194. logger.warning(message)
  195. continue
  196. src_node = self._get_normal_node(node_id=src_node_id)
  197. input_attr['shape'] = src_node.output_shape
  198. input_attr['data_type'] = src_node.output_data_type
  199. node.add_input(src_name=src_node.name, input_attr=input_attr)
  200. def _update_output_after_create_node(self):
  201. """Update the output of node after create node."""
  202. # Constants and parameter should not exist for input and output.
  203. filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
  204. for node in self._normal_node_map.values():
  205. for src_name, input_attr in node.input.items():
  206. src_node = self._get_normal_node(node_name=src_name)
  207. if src_node.type in filtered_node:
  208. continue
  209. src_node.add_output(node.name, input_attr)
  210. @staticmethod
  211. def _get_data_type_name_by_value(data_type, value, field_name='data_type'):
  212. """Get the data type name by the enum value, data_type refer to `DataType` object."""
  213. return data_type.DESCRIPTOR.fields_by_name[field_name].enum_type.values_by_number[value].name