You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

msgraph.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the MindSpore graph."""
  16. from mindinsight.datavisual.common.log import logger
  17. from mindinsight.datavisual.proto_files.mindinsight_anf_ir_pb2 import DataType
  18. from mindinsight.datavisual.common.enums import PluginNameEnum
  19. from .node import Node
  20. from .node import NodeTypeEnum
  21. from .graph import Graph
  22. from .graph import EdgeTypeEnum
  23. from .graph import escape_html
  24. class MSGraph(Graph):
  25. """The object describes the MindSpore graph, and it is defined in the anf_ir proto file."""
  26. def _parse_data(self, proto_data):
  27. """
  28. The proto data is parsed and all nodes are stored in the specified structure.
  29. Args:
  30. proto_data (anf_ir_pb2.GraphProto): Refer to anf_ir_pb2.GraphProto object.
  31. """
  32. logger.info("Start to parse graph proto data.")
  33. self._parse_op_nodes(proto_data.node)
  34. self._parse_parameters(proto_data.parameters)
  35. self._parse_consts(proto_data.const_vals)
  36. self._update_input_after_create_node()
  37. self._update_output_after_create_node()
  38. logger.info("Parse proto data end, normal node count(only contain op node, "
  39. "parameter, const): %s.", self.normal_node_count)
  40. def _parse_op_nodes(self, node_protos):
  41. """
  42. Parse `anf_ir_pb2.NodeProto` object, and create a normal node.
  43. Args:
  44. node_protos (list[anf_ir_pb2.NodeProto]): Refer to anf_ir_pb2.NodeProto.
  45. """
  46. logger.debug("Start to parse op nodes from proto.")
  47. for node_proto in node_protos:
  48. if not node_proto.name:
  49. logger.warning("Finding a node with an empty name will not save it.")
  50. continue
  51. if not node_proto.full_name or any(
  52. node_proto.full_name.lower().endswith(f'[:{plugin.value.lower()}]') for plugin in PluginNameEnum):
  53. node_name = Node.create_node_name(scope=node_proto.scope,
  54. base_name=f'{node_proto.op_type}{node_proto.name}')
  55. else:
  56. node_name = node_proto.full_name
  57. # Because the Graphviz plug-in that the UI USES can't handle these special characters,
  58. # the special characters are HTML escaped to avoid UI crash.
  59. # Doing this on the backend prevents the frontend from doing it every time.
  60. node_name = escape_html(node_name)
  61. node = Node(name=node_name, node_id=node_proto.name)
  62. node.full_name = node_proto.full_name
  63. node.type = node_proto.op_type
  64. self._parse_attributes(node_proto.attribute, node)
  65. self._parse_inputs(node_proto.input, node)
  66. node.output_i = node_proto.output_i
  67. node.scope = node_proto.scope
  68. node.output_shape = self._get_shape_by_parse_type_proto(node_proto.output_type)
  69. node.output_nums = len(node.output_shape)
  70. node.output_data_type = self._get_data_type_by_parse_type_proto(node_proto.output_type, node)
  71. self._cache_node(node)
  72. def _parse_parameters(self, parameter_protos):
  73. """
  74. Parse `anf_ir_pb2.ParameterProto` object, and create a parameter node.
  75. Args:
  76. parameter_protos (list[anf_ir_pb2.ParameterProto]): Refer to anf_ir_pb2.ParameterProto.
  77. """
  78. logger.debug("Start to parse parameters from proto.")
  79. for parameter in parameter_protos:
  80. if not parameter.name:
  81. logger.warning("Finding a parameter with an empty name will not save it.")
  82. continue
  83. node = Node(name=parameter.name, node_id=parameter.name)
  84. node.type = NodeTypeEnum.PARAMETER.value
  85. node.output_shape = self._get_shape_by_parse_type_proto(parameter.type)
  86. node.output_nums = len(node.output_shape)
  87. node.output_data_type = self._get_data_type_by_parse_type_proto(parameter.type, node)
  88. attr = dict(
  89. type=self._get_data_type_by_parse_type_proto(parameter.type, node),
  90. shape=str(self._get_shape_by_parse_type_proto(parameter.type))
  91. )
  92. node.add_attr(attr)
  93. self._cache_node(node)
  94. logger.debug("Foreach graph proto parameters, node id: %s, node name: %s, "
  95. "node def name: %s", node.node_id, node.name, parameter.name)
  96. def _parse_consts(self, consts):
  97. """
  98. Parse `anf_ir_pb2.NameValueProto` object, and create a const node.
  99. Args:
  100. consts (list[anf_ir_pb2.NameValueProto]): Refer to `anf_ir_pb2.NameValueProto` object.
  101. """
  102. logger.debug("Start to parse consts from proto.")
  103. for const in consts:
  104. if not const.key:
  105. logger.warning("Finding a const with an empty key will not save it.")
  106. continue
  107. node = Node(name=const.key, node_id=const.key)
  108. node.type = NodeTypeEnum.CONST.value
  109. if const.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  110. node.add_attr({const.key: 'dtype: ' + DataType.Name(const.value.dtype)})
  111. else:
  112. node.add_attr({const.key: str(const.value)})
  113. if const.value.dtype == DataType.DT_TENSOR:
  114. shape = list(const.value.tensor_val.dims)
  115. node.output_shape.append(shape)
  116. if const.value.tensor_val.HasField('data_type'):
  117. node.elem_types.append(DataType.Name(const.value.tensor_val.data_type))
  118. else:
  119. node.elem_types.append(DataType.Name(const.value.dtype))
  120. # dim is zero
  121. node.output_shape.append([])
  122. node.output_nums = len(node.output_shape)
  123. self._cache_node(node)
  124. def _get_shape_by_parse_type_proto(self, type_proto):
  125. """
  126. Parse proto's `message TypeProto` to get shape information.
  127. Args:
  128. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  129. Returns:
  130. list, a list of shape.
  131. """
  132. shapes = []
  133. if type_proto.HasField('data_type'):
  134. if type_proto.data_type != DataType.DT_TENSOR and \
  135. type_proto.data_type != DataType.DT_TUPLE:
  136. # Append an empty list as a placeholder
  137. # for the convenience of output number calculation.
  138. shapes.append([])
  139. return shapes
  140. if type_proto.HasField('tensor_type'):
  141. tensor_type = type_proto.tensor_type
  142. tensor_shape_proto = tensor_type.shape
  143. shape = [dim.size for dim in tensor_shape_proto.dim]
  144. shapes.append(shape)
  145. if type_proto.HasField('sequence_type'):
  146. for elem_type in type_proto.sequence_type.elem_types:
  147. shapes.extend(self._get_shape_by_parse_type_proto(elem_type))
  148. return shapes
  149. def _get_data_type_by_parse_type_proto(self, type_proto, node):
  150. """
  151. Get data type by parse type proto object.
  152. The name of the DataType, refer to `anf_ir_pb2.DataType` object.
  153. If data type is tensor or tuple, the data name we return is `data_type[element_type, element_type]`.
  154. Args:
  155. type_proto (anf_ir_pb2.TypeProto): Refer to anf_ir_pb2.TypeProto.
  156. Returns:
  157. str, the data type.
  158. """
  159. data_type_name = self._get_data_type_name_by_value(type_proto, type_proto.data_type, field_name='data_type')
  160. if type_proto.data_type == DataType.DT_TENSOR:
  161. tensor_type_proto = type_proto.tensor_type
  162. value = type_proto.tensor_type.elem_type
  163. elem_type_name = self._get_data_type_name_by_value(tensor_type_proto, value, field_name='elem_type')
  164. node.elem_types.append(elem_type_name)
  165. return f'{data_type_name}[{elem_type_name}]'
  166. if type_proto.data_type == DataType.DT_TUPLE:
  167. data_types = []
  168. for elem_type in type_proto.sequence_type.elem_types:
  169. data_types.append(self._get_data_type_by_parse_type_proto(elem_type, node))
  170. return f'{data_type_name}{str(data_types)}'
  171. node.elem_types.append(data_type_name)
  172. return data_type_name
  173. def _parse_inputs(self, input_protos, node):
  174. """
  175. Parse `anf_ir_pb2.InputProto` object.
  176. Args:
  177. input_protos (list[anf_ir_pb2.InputProto]): Refer to `anf_ir_pb2.InputProto` object.
  178. node (Node): Refer to `Node` object, it is used to log message and update input.
  179. """
  180. for input_proto in input_protos:
  181. if not input_proto.name:
  182. logger.warning("The name in input proto of node(%s) is empty, will ignore.", node.name)
  183. continue
  184. edge_type = EdgeTypeEnum.DATA.value if not input_proto.type else EdgeTypeEnum.CONTROL.value
  185. # Notice:
  186. # 1. The name in the input proto is the node id of the Node object.
  187. # 2. In the current step, the shape of source node cannot be obtained,
  188. # so it is set to empty list by default, and the next step will update it.
  189. # 3. Same with scope, set the default value first.
  190. input_attr = {
  191. "shape": [],
  192. "edge_type": edge_type,
  193. "independent_layout": False,
  194. 'data_type': ''
  195. }
  196. node.add_inputs(src_name=input_proto.name, input_attr=input_attr)
  197. def _parse_attributes(self, attributes, node):
  198. """
  199. Parse `anf_ir_pb2.AttributeProto` object., and Filters large attribute values.
  200. Args:
  201. attributes (list[anf_ir_pb2.AttributeProto]): Refer to `anf_ir_pb2.AttributeProto` object.
  202. node (Node): Refer to `Node` object, it is used to log message and update attr.
  203. """
  204. for attr in attributes:
  205. if attr.value.ByteSize() > self.MAX_NODE_ATTRIBUTE_VALUE_BYTES:
  206. message = f"The attribute value of node({node.name}) " \
  207. f"is over {self.MAX_NODE_ATTRIBUTE_VALUE_BYTES} Bytes, will ignore."
  208. logger.warning(message)
  209. continue
  210. node.add_attr({attr.name: str(attr.value)})
  211. def _update_input_after_create_node(self):
  212. """Update the input of node after create node."""
  213. for node in self._normal_node_map.values():
  214. for src_node_id, input_attr in dict(node.inputs).items():
  215. node.delete_inputs(src_node_id)
  216. if not self._is_node_exist(node_id=src_node_id):
  217. message = f"The input node could not be found by node id({src_node_id}) " \
  218. f"while updating the input of the node({node})"
  219. logger.warning(message)
  220. continue
  221. src_node = self._get_normal_node(node_id=src_node_id)
  222. input_attr['shape'] = src_node.output_shape
  223. input_attr['data_type'] = src_node.output_data_type
  224. node.add_inputs(src_name=src_node.name, input_attr=input_attr)
  225. def _update_output_after_create_node(self):
  226. """Update the output of node after create node."""
  227. # Constants and parameter should not exist for input and output.
  228. filtered_node = {NodeTypeEnum.CONST.value, NodeTypeEnum.PARAMETER.value}
  229. for node in self._normal_node_map.values():
  230. for src_name, input_attr in node.inputs.items():
  231. src_node = self._get_normal_node(node_name=src_name)
  232. if src_node.type in filtered_node:
  233. continue
  234. src_node.add_outputs(node.name, input_attr)
  235. @staticmethod
  236. def _get_data_type_name_by_value(data_type, value, field_name='data_type'):
  237. """Get the data type name by the enum value, data_type refer to `DataType` object."""
  238. return data_type.DESCRIPTOR.fields_by_name[field_name].enum_type.values_by_number[value].name