You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph.py 24 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the basic graph.
  17. """
  18. import time
  19. from enum import Enum
  20. from collections import defaultdict
  21. from mindinsight.datavisual.common.exceptions import NodeNotInGraphError
  22. from mindinsight.datavisual.common.log import logger
  23. from mindinsight.utils.exceptions import ParamMissError
  24. from mindinsight.utils.exceptions import ParamValueError
  25. from .node import NodeTypeEnum
  26. from .node import Node
  class EdgeTypeEnum(Enum):
      """Enumerate the edge types of a node: ``'control'`` and ``'data'``."""

      CONTROL = 'control'
      DATA = 'data'
  31. class Graph:
  32. """The `Graph` object is used to describe a graph file."""
  33. # Limit the size of a single attribute value per node to avoid storing too much data
  34. MAX_NODE_ATTRIBUTE_VALUE_BYTES = 1024
  35. # In the same scope, the number of children of the same type exceeds this threshold, and we will combine them.
  36. MIN_GROUP_NODE_COUNT = 5
  37. def __init__(self):
  38. # Used to cache all nodes, and the key is node name, value is `Node` object.
  39. self._normal_node_map = {}
  40. self._node_id_map_name = {}
  41. # The additional caching of Const and Parameter is to handle the Const
  42. # and Parameter nodes separately later.
  43. self._const_node_temp_cache = {}
  44. self._parameter_node_temp_cache = {}
  def build_graph(self, proto_data):
      """
      Build the graph from the given data and log how long it took.

      Args:
          proto_data: The raw data, handed unchanged to `_parse_data`.
      """
      logger.info("Start to build graph")
      # A monotonic clock is used so that the reported duration cannot be
      # distorted (or even become negative) by system clock adjustments.
      start_time = time.perf_counter()

      # Notice:
      # The following methods are interdependent and cannot be switched at will.
      self._parse_data(proto_data)
      self._add_variable_nodes(NodeTypeEnum.PARAMETER.value)
      self._build_aggregation_scope_nodes()
      self._process_independent_layout()
      self._build_name_scope_nodes()

      # Since const nodes are not aggregated, adding them at the end can save a lot of computation.
      self._add_variable_nodes(NodeTypeEnum.CONST.value)
      self._calc_subnode_count()

      precision = 6
      time_consuming = round(time.perf_counter() - start_time, precision)
      logger.info("Build graph end, all node count: %s, const count: %s, parameter count: %s, time-consuming: %s s.",
                  self.normal_node_count, len(self._const_node_temp_cache),
                  len(self._parameter_node_temp_cache), time_consuming)
  64. def exist_node(self, name):
  65. """
  66. Check node exist in graph.
  67. Args:
  68. name (str): The node name.
  69. Returns:
  70. bool, if node exists, will return True.
  71. """
  72. if name is None:
  73. return False
  74. return self._is_node_exist(node_name=name)
  75. def list_node_by_scope(self, scope=None):
  76. """
  77. List nodes by the scope of nodes. The scope of a node is the same as its parent node name.
  78. Args:
  79. scope (str): A scope of nodes.
  80. Returns:
  81. list[dict], a list object contain `Node` object.
  82. """
  83. scope = "" if scope is None else scope
  84. nodes = []
  85. for node in self._normal_node_map.values():
  86. if node.scope == scope:
  87. nodes.append(node.to_dict())
  88. return nodes
  89. def search_node_names(self, content, offset, limit):
  90. """
  91. Search node names by content.
  92. Args:
  93. content (Union[str, None]): This content can be the key content of the node to search,
  94. if None, will get all node names.
  95. offset (int): An offset for page. Ex, offset is 0, mean current page is 1.
  96. limit (int): An offset for page. Ex, offset is 0, mean current page is 1.
  97. Returns:
  98. list[str], a list of node names.
  99. """
  100. if content is not None:
  101. content = content.lower()
  102. catch_names = [name for name in self._normal_node_map if content in name.lower()]
  103. else:
  104. catch_names = list(self._normal_node_map)
  105. catch_names = sorted(catch_names)
  106. real_offset = offset * limit
  107. return catch_names[real_offset:real_offset+limit]
  108. def search_single_node(self, node_name):
  109. """
  110. Search node, and return every layer nodes until this node.
  111. Args:
  112. node_name (str): The name of node.
  113. Returns:
  114. dict, a dict object, format is :
  115. item_object = {'nodes': [<Node object>],
  116. 'scope_name': '<Node scope>',
  117. 'children': {<item_object>}}
  118. """
  119. if node_name and not self.exist_node(name=node_name):
  120. raise NodeNotInGraphError(node_name=node_name)
  121. response = {}
  122. nodes = self.list_node_by_scope()
  123. response.update({
  124. 'nodes': nodes,
  125. 'scope_name': '',
  126. 'children': {}
  127. })
  128. children = response['children']
  129. index = node_name.find('/')
  130. while index != -1:
  131. scope = node_name[:index]
  132. nodes = self.list_node_by_scope(scope)
  133. children.update({
  134. 'nodes': nodes,
  135. 'scope_name': scope,
  136. 'children': {}
  137. })
  138. children = children['children']
  139. index = node_name.find('/', index+1)
  140. return response
  141. def _parse_data(self, proto_data):
  142. """
  143. This method will parse the data and create basic nodes to store in the cache.
  144. The graph is then built based on the cache.
  145. """
  146. raise NotImplementedError("Before you can build a graph, you need to parse the data.")
  def _build_name_scope_nodes(self):
      """
      Build name scope node by every node name.

      We create the name scope node by the slash('/') in the node name.
      For example, if a node name is "Default/add", we generate a scope named 'Default' based on slash('/') and
      create a name scope node named 'Default'.
      """
      logger.info("Start to build name scope nodes.")
      scope_node_map = {}
      # Iterate over a snapshot: resolving a name conflict below renames a cached
      # node, which removes and re-inserts entries of `self._normal_node_map`.
      # Mutating a dict while iterating its live view may skip or repeat entries.
      for name, node in list(self._normal_node_map.items()):
          index = name.find('/')
          pre_index = None
          while index > 0:
              scope = name[:index]
              scope_node = scope_node_map.get(scope)
              if scope_node is None:
                  if self._is_node_exist(node_name=scope):
                      exist_node = self._get_normal_node(node_name=scope)
                      if exist_node.type == NodeTypeEnum.AGGREGATION_SCOPE.value:
                          # This scope is aggregation scope, so we don't have to do anything.
                          pre_index = index
                          index = name.find('/', pre_index + 1)
                          continue

                      # We find a node name that conflicts with the current scope and rename the node
                      self._update_conflict_node(conflict_name=scope)

                  # We create a node for current scope.
                  scope_node = Node(scope, node_id=scope)
                  scope_node.type = NodeTypeEnum.NAME_SCOPE.value
                  # The parent scope is the part of the name before the previous slash.
                  scope_node.scope = '' if pre_index is None else name[:pre_index]
                  scope_node_map.update({scope_node.name: scope_node})

              # Inherit input and output from sub nodes.
              self._inherit_input_output_from_subnode(scope_node, subnode_list=[node])

              pre_index = index
              index = name.find('/', pre_index + 1)

      # Cache all the scope node to normal node dict
      for node in scope_node_map.values():
          self._cache_node(node)
  def _update_conflict_node(self, conflict_name):
      """
      Rename the cached node named `conflict_name`.

      The node keeps its scope; the last part of its name is wrapped in
      parentheses, and the rename is applied with `update_parent=True`.

      Args:
          conflict_name (str): The name of the cached node to rename.
      """
      node = self._get_normal_node(node_name=conflict_name)
      last_part = conflict_name.rsplit('/', 1)[-1]
      renamed = Node.create_node_name(scope=node.scope, base_name=f'({last_part})')
      self._update_node_name_of_cache(node, renamed, update_parent=True)
  189. def _inherit_input_output_from_subnode(self, parent_node, subnode_list, filtered_type=None):
  190. """
  191. Adds the input and output of all direct child nodes to the current node.
  192. Args:
  193. parent_node (Node): The nodes that inherit the input and output of the child nodes.
  194. subnode_list (list[Node]): A list of child nodes that are inherited from the input and output.
  195. filtered_type (set(str)): Filter some input and output that do not require inheritance
  196. based on the node type. Default is filter const node.
  197. Note:
  198. - Only the inputs and outputs of the external scope are inherited.
  199. - Before add_const_node method, if the input is a const,
  200. the scope of the const node is not startswith the name of parent node.
  201. So in this scenario, we need to filter the const nodes.
  202. """
  203. filtered_type = {NodeTypeEnum.CONST.value} if filtered_type is None else filtered_type
  204. for method in ['input', 'output', 'proxy_input', 'proxy_output']:
  205. for node in subnode_list:
  206. for item_name, item_attr in getattr(node, method).items():
  207. target_node = self._get_normal_node(node_name=item_name)
  208. if item_name.startswith(f'{parent_node.name}/'):
  209. # Own scope, ignore
  210. continue
  211. if target_node.type in filtered_type:
  212. continue
  213. getattr(parent_node, f'add_{method}')(item_name, item_attr)
  def _build_aggregation_scope_nodes(self):
      """
      Aggregate nodes of the same type under one scope once their count reaches the threshold.

      Every group returned as aggregatable by `_find_group_nodes` gets a new
      aggregation scope node, and the grouped nodes are renamed so that they
      live under it.

      Note:
          The threshold value refers to the `MIN_GROUP_NODE_COUNT`.
      """
      logger.info("Start to build aggregation scope nodes.")
      group_node_map, filtered_group_names = self._find_group_nodes()

      # Create the merged scope nodes.
      created_scope_nodes = {}
      for group_index, group_name in enumerate(filtered_group_names):
          # A group name is "<scope>/<type>"; without a slash the scope is empty.
          scope, _, op_type = group_name.rpartition('/')
          members = group_node_map[group_name]
          count = len(group_node_map.get(group_name))

          aggregation_node_name = Node.create_node_name(
              scope=scope, base_name=f'{op_type}[{count}]_{group_index}')
          aggregation_scope_node = Node(name=aggregation_node_name, node_id=aggregation_node_name)
          aggregation_scope_node.subnode_count = count
          aggregation_scope_node.scope = scope
          aggregation_scope_node.type = NodeTypeEnum.AGGREGATION_SCOPE.value

          # Move every member under the new aggregation scope.
          for member in members:
              member_base_name = member.name.split('/')[-1]
              member_new_name = Node.create_node_name(scope=aggregation_node_name, base_name=member_base_name)
              member.scope = aggregation_node_name

              # Since the name scope has not been created, there is no need to update the parent node.
              self._update_node_name_of_cache(member, member_new_name, update_parent=False)

          # Cache this node
          self._cache_node(aggregation_scope_node)
          created_scope_nodes[group_name] = aggregation_scope_node

      # Adds the input and output of all direct child nodes to the aggregation node.
      for group_name, aggregation_scope_node in created_scope_nodes.items():
          self._inherit_input_output_from_subnode(aggregation_scope_node, group_node_map[group_name])
  def _find_group_nodes(self):
      """
      Find nodes that can be grouped into a group.

      For direct child nodes in a scope, we divide them into multiple groups by node type.
      Const nodes and name scope nodes are excluded,
      because these types of nodes are not operational nodes.

      Returns:
          tuple, the mapping of group name to the nodes of that group, and the list
          of group names holding at least `MIN_GROUP_NODE_COUNT` nodes.
      """
      skipped_types = {
          NodeTypeEnum.CONST.value,
          NodeTypeEnum.NAME_SCOPE.value,
      }

      group_node_map = defaultdict(list)
      for node in self._normal_node_map.values():
          if node.type not in skipped_types:
              # The group key combines the scope of the node with its type.
              group_key = Node.create_node_name(scope=node.scope, base_name=node.type)
              group_node_map[group_key].append(node)

      # Keep only the groups that are large enough to be aggregated.
      filtered_group_names = [
          group_key
          for group_key, members in group_node_map.items()
          if len(members) >= self.MIN_GROUP_NODE_COUNT
      ]

      return group_node_map, filtered_group_names
  def _add_variable_nodes(self, node_type):
      """
      We create the Const nodes or Parameter nodes in this method.

      For every cached node, each input that refers to a node of the requested
      type is replaced by a copy of that node placed in the scope of the
      consuming node. The original nodes of that type are removed from the
      cache afterwards.

      Args:
          node_type (str): Decide which type of node to add.
              Optional is `NodeTypeEnum.CONST.value` and `NodeTypeEnum.PARAMETER.value`.

      Raises:
          ParamValueError: If `node_type` is neither const nor parameter.

      Note:
          This method relies on the presence of data in the const cache or parameter cache.
      """
      # Validate before touching any node, so that an unsupported type cannot
      # leave the graph half rewritten.
      if node_type == NodeTypeEnum.CONST.value:
          temp_cache = self._const_node_temp_cache
      elif node_type == NodeTypeEnum.PARAMETER.value:
          temp_cache = self._parameter_node_temp_cache
      else:
          raise ParamValueError("The node type should be const or parameter.")

      logger.info("Start to add %s nodes to each scope in graph.", node_type)
      node_map = {}
      for node in self._normal_node_map.values():
          # Iterate over a copy, because the inputs of `node` are rewritten below.
          for src_name, input_attr in dict(node.input).items():
              if not temp_cache.get(src_name):
                  # The input is not a node of the requested type.
                  continue

              variable_name = Node.create_node_name(scope=node.scope, base_name=src_name)
              variable_node = node_map.get(variable_name)
              if not variable_node:
                  # First consumer in this scope: create the scoped copy once.
                  cache_node = self._get_normal_node(node_name=src_name)
                  variable_node = Node(name=variable_name, node_id=variable_name)
                  Node.copy_node_without_input_output(cache_node, variable_node)
                  variable_node.scope = node.scope

              variable_node.add_output(dst_name=node.name, output_attr=input_attr)
              node_map.update({variable_name: variable_node})

              # Re-point the consumer from the original node to the scoped copy.
              node.delete_input(src_name)
              node.add_input(variable_name, input_attr)

      for node in node_map.values():
          self._cache_node(node)

      # Remove nodes that are not used in the cache.
      unused_names = set(temp_cache) - set(node_map)
      self._delete_nodes_of_cache(unused_names)
  313. def _calc_subnode_count(self):
  314. """Calc all the direct sub node count."""
  315. subnode_count_map = defaultdict(int)
  316. for node in self._normal_node_map.values():
  317. if not node.scope:
  318. continue
  319. if not self._is_node_exist(node_name=node.scope):
  320. logger.warning("Can not find a scope node by the given name(%s), "
  321. "the name scope nodes may not have been created.", node.scope)
  322. continue
  323. subnode_count_map[node.scope] = subnode_count_map[node.scope] + 1
  324. for name, count in subnode_count_map.items():
  325. node = self._get_normal_node(node_name=name)
  326. node.subnode_count = count
  327. def _get_normal_node(self, node_id=None, node_name=None):
  328. """Query node by node id or node name."""
  329. if node_id is not None:
  330. name = self._node_id_map_name.get(node_id)
  331. node = self._normal_node_map.get(name)
  332. return node
  333. if node_name is not None:
  334. return self._normal_node_map.get(node_name)
  335. raise ParamMissError('Method requires an argument that is not None.')
  336. def _is_node_exist(self, node_id=None, node_name=None):
  337. """Check node is exist."""
  338. if node_id is not None:
  339. return bool(self._node_id_map_name.get(node_id))
  340. if node_name is not None:
  341. return bool(self._normal_node_map.get(node_name))
  342. raise ParamMissError('Method requires an argument that is not None.')
  343. @property
  344. def normal_node_count(self):
  345. """Get the normal node count."""
  346. return len(self._normal_node_map)
  def _cache_node(self, node):
      """
      Store the node in the cache.

      Every node is stored in the normal node map and in the id-to-name map.
      Const and Parameter nodes are stored in their own cache as well.
      """
      node_name = node.name

      # Notice:
      # The additional caching of Const and Parameter is to handle the Const and Parameter nodes separately later.
      if node.type == NodeTypeEnum.CONST.value:
          self._const_node_temp_cache[node_name] = node
      if node.type == NodeTypeEnum.PARAMETER.value:
          self._parameter_node_temp_cache[node_name] = node

      self._normal_node_map[node_name] = node
      self._node_id_map_name[node.node_id] = node_name
  def _delete_nodes_of_cache(self, node_names):
      """
      Delete the named nodes from every cache.

      Args:
          node_names (Iterable[str]): The names of the nodes to remove. Each name
              must be present in the normal node map.
      """
      logger.debug("These nodes will be removed from the cache, node names: %s.", str(node_names))
      temp_caches = (self._parameter_node_temp_cache, self._const_node_temp_cache)
      for name in node_names:
          for temp_cache in temp_caches:
              if temp_cache.get(name):
                  temp_cache.pop(name)

          # Look the node up first: its id is needed to clean the id-to-name map.
          cached_node = self._get_normal_node(node_name=name)
          self._normal_node_map.pop(name)
          self._node_id_map_name.pop(cached_node.node_id)
  def _update_node_name_of_cache(self, node, new_name, update_parent=False):
      """
      Rename a node which is stored in cache and fix every edge that refers to it.

      Args:
          node (Node): The node that will be renamed.
          new_name (str): The new name.
          update_parent (bool): Determines whether the input and output of the parent node need to be updated.
      """
      logger.debug('Update node name of cache, node(%s), new name is %s.', str(node), new_name)
      edge_kinds = ('input', 'output', 'proxy_input', 'proxy_output')
      origin_name = node.name
      node.name = new_name

      # Find all nodes whose input and output need to be modified.
      affected_nodes = {}
      for method in edge_kinds:
          for target_name in getattr(node, method):
              target_node = self._get_normal_node(node_name=target_name)
              if target_node is None:
                  message = f"Node should not be None, name: {target_name}, {method}: {list(getattr(node, method))}."
                  logger.error(message)
                  continue

              affected_nodes[target_name] = target_node
              if not update_parent:
                  continue

              # Every prefix of the target name that ends before a '/' is an enclosing scope.
              slash_index = target_name.find('/')
              while slash_index != -1:
                  scope_name = target_name[:slash_index]
                  slash_index = target_name.find('/', slash_index + 1)

                  if affected_nodes.get(scope_name):
                      continue

                  scope_node = self._get_normal_node(node_name=scope_name)
                  if scope_node is None:
                      message = f"Can not find the scope node by scope name({scope_name}), " \
                                f"may be this scope node has not been built."
                      logger.debug(message)
                      continue

                  affected_nodes[scope_name] = scope_node

      # Update the input and output of the affected nodes.
      for affected_node in affected_nodes.values():
          for method in edge_kinds:
              edge_attr = getattr(affected_node, method).get(origin_name)
              if edge_attr is None:
                  # This kind of edge does not refer to the renamed node, so it is skipped.
                  continue

              # Delete the old attribute and update new name to source node or destination node.
              getattr(affected_node, f'delete_{method}')(origin_name)
              getattr(affected_node, f'add_{method}')(new_name, edge_attr)

      # Delete the origin node in cache, then store it again under the new name.
      self._delete_nodes_of_cache(node_names=[origin_name])
      self._cache_node(node)
  def _process_independent_layout(self):
      """
      Handle separate layout nodes.

      An aggregation scope node whose base name contains the parameter type is
      marked with `independent_layout`, together with its direct sub nodes, and
      proxy edges are added for their outputs.
      """
      layout_scope_map = {}
      for node in self._normal_node_map.values():
          base_name = node.name.split('/')[-1]
          if node.type == NodeTypeEnum.AGGREGATION_SCOPE.value and NodeTypeEnum.PARAMETER.value in base_name:
              layout_scope_map[node.name] = node

      # Find all sub nodes
      children_of_scope = defaultdict(list)
      for node in self._normal_node_map.values():
          if layout_scope_map.get(node.scope):
              children_of_scope[node.scope].append(node)

      # Notice:
      # The following processing is only done for the parameter node, other types of nodes are not processed.
      # Later, when you need to extend to other nodes, the code needs to be adjusted.
      for scope_node in layout_scope_map.values():
          scope_node.independent_layout = True

          # Iterate over a copy of the outputs of the scope node.
          for target_name, target_attr in dict(scope_node.output).items():
              proxy_attr = {'edge_type': target_attr['edge_type']}

              target_node = self._get_normal_node(node_name=target_name)
              target_node.add_proxy_input(scope_node.name, proxy_attr)

              # Note:
              # If the source node and the destination node are not in the same scope,
              # the proxy node is presented as scope in order to simplify the flow of the display data.
              # For example, the data flow is parameter[5]_1 -> add[5]_1/add1
              # we create a scope proxy node(add[5]_1) for parameter[5]_1,
              # so there is a proxy data flow parameter[5]_1 -> add[5]_1 instead of parameter[5]_1 -> add[5]_1/add1.
              if target_node.scope == scope_node.scope:
                  scope_node.add_proxy_output(target_name, proxy_attr)
              else:
                  target_scope_node = self._get_normal_node(node_name=target_node.scope)
                  scope_node.add_proxy_output(target_node.scope, proxy_attr)
                  target_scope_node.add_proxy_input(scope_node.name, proxy_attr)

          for subnode in children_of_scope[scope_node.name]:
              subnode.independent_layout = True
              for target_name, target_attr in dict(subnode.output).items():
                  proxy_attr = {'edge_type': target_attr['edge_type']}
                  target_node = self._get_normal_node(node_name=target_name)
                  if target_node.scope == scope_node.scope:
                      subnode.add_proxy_output(target_name, proxy_attr)
                  else:
                      subnode.add_proxy_output(target_node.scope, proxy_attr)

                  # Flag the matching input edge on the consuming node as well.
                  input_attr = target_node.input[subnode.name]
                  input_attr['independent_layout'] = True
                  target_node.add_input(subnode.name, input_attr)