You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debugger_multigraph.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the basic graph."""
  16. import copy
  17. from mindinsight.debugger.common.log import LOGGER as log
  18. from mindinsight.datavisual.data_transform.graph.node import Node, NodeTypeEnum
  19. from .debugger_graph import DebuggerGraph
  20. class DebuggerMultiGraph(DebuggerGraph):
  21. """The `DebuggerMultiGraph` object provides interfaces to describe a debugger multigraph."""
  22. def add_graph(self, graph_dict):
  23. """
  24. add graphs to DebuggerMultiGraph
  25. Args:
  26. graph_dict (dict): The <graph_name, graph_object> dict.
  27. """
  28. if len(graph_dict) == 1:
  29. graph = list(graph_dict.values())[0]
  30. self._normal_node_map = graph.normal_node_map
  31. self._node_id_map_name = graph.node_id_map_name
  32. self._const_node_temp_cache = graph.const_node_temp_cache
  33. self._parameter_node_temp_cache = graph.parameter_node_temp_cache
  34. self._leaf_nodes = graph.leaf_nodes
  35. self._full_name_map_name = graph.full_name_map_name
  36. else:
  37. for graph_name, graph in graph_dict.items():
  38. log.debug("add graph %s into whole graph.", graph_name)
  39. # add nodes
  40. normal_nodes = copy.deepcopy(graph.normal_node_map)
  41. for _, node_obj in normal_nodes.items():
  42. pre_scope = graph_name + "/"
  43. node_obj.name = pre_scope + node_obj.name
  44. node_obj.full_name = pre_scope + node_obj.full_name
  45. if node_obj.scope:
  46. node_obj.scope = pre_scope + node_obj.scope
  47. else:
  48. node_obj.scope = graph_name
  49. # update inputs
  50. old_inputs = copy.deepcopy(node_obj.inputs)
  51. for src_name, input_attr in old_inputs.items():
  52. new_src_name = graph_name + "/" + src_name
  53. node_obj.add_inputs(new_src_name, input_attr)
  54. node_obj.delete_inputs(src_name)
  55. # update_outputs
  56. old_outputs = copy.deepcopy(node_obj.outputs)
  57. for dst_name, output_attr in old_outputs.items():
  58. new_dst_name = graph_name + "/" + dst_name
  59. node_obj.add_outputs(new_dst_name, output_attr)
  60. node_obj.delete_outputs(dst_name)
  61. self._cache_node(node_obj)
  62. # add graph_node
  63. node = Node(name=graph_name, node_id=graph_name)
  64. node.type = NodeTypeEnum.NAME_SCOPE.value
  65. node.subnode_count = len(graph.list_node_by_scope())
  66. self._cache_node(node)
  67. self._leaf_nodes = self._get_leaf_nodes()
  68. self._full_name_map_name = self._get_leaf_node_full_name_map()
  69. log.info(
  70. "Build multi_graph end, all node count: %s, const count: %s, parameter count: %s.",
  71. self.normal_node_count, len(self._const_node_temp_cache),
  72. len(self._parameter_node_temp_cache))