@@ -18,8 +18,8 @@ from collections import namedtuple
 import numpy as np
+from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
 from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
-from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum
 # translate the MindSpore type to numpy type.
 NUMPY_TYPE_MAP = {
@@ -536,13 +536,11 @@ class DebuggerServer:
         node_infos = []
         for node_name in node_names:
             node_type = graph_stream.get_node_type(node_name)
             # optimize later
             if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
-                sub_nodes = graph_stream.get_nodes(node_name)
+                sub_nodes = graph_stream.get_nodes_by_scope(node_name)
                 sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
                              for node in sub_nodes]
                 node_infos.extend(sub_infos)
                 continue
             full_name = graph_stream.get_full_name(node_name)
             node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type))
         return node_infos
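
For reference, the AGGREGATION_SCOPE branch above flattens one scope entry into one NodeBasicInfo per child returned by get_nodes_by_scope. A minimal, self-contained sketch of that expansion (stand-in namedtuples and made-up node names, not the module's real definitions):

from collections import namedtuple

# Illustrative stand-ins: the real NodeBasicInfo namedtuple and graph nodes are
# defined elsewhere in the debugger module, not in this hunk.
NodeBasicInfo = namedtuple('NodeBasicInfo', ['name', 'full_name', 'type'])
Node = namedtuple('Node', ['name', 'full_name', 'type'])

def basic_infos_for_scope(sub_nodes):
    """Flatten the children of an aggregation scope into basic-info tuples."""
    return [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
            for node in sub_nodes]

children = [Node('conv/Conv2D-op1', 'Default/conv/Conv2D-op1', 'Conv2D'),
            Node('conv/Conv2D-op2', 'Default/conv/Conv2D-op2', 'Conv2D')]
print(basic_infos_for_scope(children))  # one NodeBasicInfo per child node
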
@@ -615,17 +613,6 @@ class DebuggerServer:
         return {'metadata': {'state': current_state}}
-    def _validate_node_type(self, node_name):
-        """Check the node type in node control."""
-        if not node_name:
-            return
-        node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
-        unsupported_types = [item.value for item in list(NodeTypeEnum)]
-        if node_type in unsupported_types:
-            log.error("Invalid node type. %s", node_name)
-            raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for "
-                                          "continue to command.")
     def _construct_run_event(self, params):
         """
         Construct run cmd from input control params.
@@ -639,7 +626,7 @@
                 - steps (int): Specify the steps that training should run.
                   Used when `level` is `step`.
-                - full_name (str): Specify the name of the node. Used when `level` is `node`.
+                - name (str): Specify the name of the node. Used when `level` is `node`.
         Returns:
             EventReply, control event with run command.
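
Put together, the control params documented above amount to a small dict. Two hypothetical examples of what a caller might pass (the node name is made up for illustration):

# Run a fixed number of steps.
step_params = {'level': 'step', 'steps': 10}

# Run until a given node is reached; the display name is resolved to the graph
# full name inside _construct_run_event.
node_params = {'level': 'node', 'name': 'Default/network/conv1/Conv2D-op1'}
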
@@ -652,10 +639,11 @@
                 steps = 1
             run_cmd = RunCMD(run_level='step', run_steps=steps)
         elif level == 'node':
-            self._validate_node_type(params.get('name'))
-            name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(
-                params['name'])
-            if not name:
+            name = params.get('name')
+            if name:
+                self._validate_leaf_name(name)
+                name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name)
+            else:
                 name = ''
             run_cmd = RunCMD(run_level='node', node_name=name)
         else:
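
_validate_leaf_name itself is not part of this excerpt. A rough sketch of the kind of check the call implies, using stand-in type values and a plain ValueError instead of the project's DebuggerParamValueError (assumed behaviour, not the PR's actual helper):

SCOPE_TYPES = {'name_scope', 'aggregation_scope'}  # stand-ins for the NodeTypeEnum scope values

def validate_leaf_name(node_name, node_type):
    """Sketch: a node-level run command should target a single leaf node, not a scope."""
    if not node_name or node_type in SCOPE_TYPES:
        raise ValueError(f"Node {node_name!r} is not a leaf node and cannot be run to.")
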
| @@ -16,7 +16,6 @@ | |||
| from collections import deque | |||
| from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph | |||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||
| from mindinsight.debugger.common.exceptions.exceptions import \ | |||
| DebuggerNodeNotInGraphError, DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import logger as log | |||
@@ -90,14 +89,14 @@ class DebuggerGraph(MSGraph):
     def search_nodes_by_pattern(self, pattern):
         """
-        Search node names by a given pattern.
+        Search nodes by a given pattern.
         Args:
             pattern (Union[str, None]): The pattern of the node to search,
                 if None, return all node names.
         Returns:
-            list[(str, str)], a list of tuple (node name, node type).
+            list[Node], a list of nodes.
         """
         if pattern is not None:
             pattern = pattern.lower()
@@ -106,7 +105,7 @@
                 if pattern in name.lower()
             ]
         else:
-            searched_nodes = [node for name, node in self._leaf_nodes.items()]
+            searched_nodes = [node for _, node in self._leaf_nodes.items()]
         return searched_nodes
     def _build_node_tree(self, node_name, node_type):
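
search_nodes_by_pattern lower-cases the pattern and keeps every leaf node whose name contains it as a substring; with no pattern it returns all leaf nodes. A small stand-alone illustration of that matching rule (hypothetical node names, plain objects instead of the graph's node class):

class FakeNode:
    """Bare stand-in for a graph node; only the name attribute is used here."""
    def __init__(self, name):
        self.name = name

leaf_nodes = {name: FakeNode(name)
              for name in ('Default/conv1/Conv2D-op1', 'Default/fc/MatMul-op3')}

pattern = 'Conv'.lower()
matches = [node for name, node in leaf_nodes.items() if pattern in name.lower()]
print([node.name for node in matches])  # only the Conv2D node matches
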
@@ -147,13 +146,8 @@
         if node_name and not self.exist_node(name=node_name):
             raise DebuggerNodeNotInGraphError(node_name=node_name)
-        node = self._leaf_nodes.get(node_name)
-        if node is not None:
-            node_type = node.type
-        else:
-            node_type = NodeTypeEnum.NAME_SCOPE.value
-        return node_type
+        node = self._normal_node_map.get(node_name)
+        return node.type
     def get_tensor_history(self, node_name, depth=0):
         """
@@ -71,17 +71,21 @@ class WatchNodeTree:
         """The property of watch status about current node."""
         return self._watch_status
-    def enable_watch_status(self):
-        """The property of watch status about current node."""
-        self._watch_status = WatchNodeTree.TOTAL_WATCH
+    def update_metadata(self, node_type, full_name, watch_status):
+        """Update the metadata for watched node."""
+        self._full_name = full_name
+        self._node_type = self._translate_node_type(node_type)
+        self._watch_status = watch_status
     @staticmethod
     def _translate_node_type(node_type):
         """Translate node type to watch node type."""
-        if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value or \
-                node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
-            return 'scope'
-        return 'leaf'
+        flag = node_type
+        if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value:
+            flag = 'scope'
+        elif node_type != NodeTypeEnum.AGGREGATION_SCOPE.value:
+            flag = 'leaf'
+        return flag
     def get(self, sub_name):
         """Get sub node."""
@@ -104,10 +108,11 @@
         log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
         scope_names = node_name.split('/', 1)
         if len(scope_names) == 1:
-            if not self.get(node_name):
+            target_node = self.get(node_name)
+            if not target_node:
                 self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
             else:
-                self.get(node_name).enable_watch_status()
+                target_node.update_metadata(node_type, full_name, WatchNodeTree.TOTAL_WATCH)
             return
         scope_name, sub_names = scope_names
@@ -232,7 +237,8 @@
             cur_watch_node (WatchNodeTree): The current watch node.
             watch_node_list (list[WatchNodeTree]): The list of total watched node.
         """
-        if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH:
+        if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH and \
+                cur_watch_node.node_type != NodeTypeEnum.AGGREGATION_SCOPE.value:
             watch_node_list.append(cur_watch_node)
             return
         for _, watch_node in cur_watch_node.get_children():
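
With the extra condition above, a fully watched aggregation scope is no longer emitted as a single watch node; the traversal keeps descending so its children are collected individually. A compact sketch of that recursion over a stub tree (stand-in constants and made-up names):

TOTAL_WATCH = 1                          # stand-in for WatchNodeTree.TOTAL_WATCH
AGGREGATION_SCOPE = 'aggregation_scope'  # stand-in for NodeTypeEnum.AGGREGATION_SCOPE.value

class StubWatchNode:
    """Minimal stand-in for WatchNodeTree with just the fields used here."""
    def __init__(self, name, node_type, watch_status, children=()):
        self.name = name
        self.node_type = node_type
        self.watch_status = watch_status
        self._children = list(children)

    def get_children(self):
        return [(child.name, child) for child in self._children]

def collect_watch_nodes(cur, out):
    """Collect fully watched nodes, always descending through aggregation scopes."""
    if cur.watch_status == TOTAL_WATCH and cur.node_type != AGGREGATION_SCOPE:
        out.append(cur)
        return
    for _, child in cur.get_children():
        collect_watch_nodes(child, out)

leaves = [StubWatchNode('conv-op1', 'leaf', TOTAL_WATCH),
          StubWatchNode('conv-op2', 'leaf', TOTAL_WATCH)]
scope = StubWatchNode('conv[Aggregation]', AGGREGATION_SCOPE, TOTAL_WATCH, leaves)
result = []
collect_watch_nodes(scope, result)
print([node.name for node in result])  # the two leaves, not the aggregation scope
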
@@ -143,9 +143,17 @@ class GraphHandler(StreamHandlerBase):
         return {'nodes': nodes}
     def get_node_names(self, pattern=None):
         """Get graph nodes according to pattern."""
         return self._graph.search_nodes_by_pattern(pattern)
+    def get_nodes_by_scope(self, scope_name):
+        """
+        Get nodes by a given scope name.
+        Args:
+            scope_name (str): The name of the scope.
+        Returns:
+            list[Node], a list of nodes.
+        """
+        return self._graph.search_nodes_by_pattern(scope_name)
     def get_searched_node_list(self):
         """Get searched node list."""
@@ -226,9 +226,10 @@ class TensorHandler(StreamHandlerBase):
     def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
         """Update has_prev_step field in tensor info."""
         flag = None
-        if node_type == NodeTypeEnum.PARAMETER.value:
+        cur_tensor_value = bool(tensor_info and tensor_info.get('value') is not None)
+        if node_type == NodeTypeEnum.PARAMETER.value and cur_tensor_value:
             flag = self._get_prev_tensor_value_status(tensor_name)
-        if flag and tensor_info:
+        if flag:
             tensor_info['has_prev_step'] = True
         return flag
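
After this change the previous-step flag is only looked up for parameter nodes that actually have a value at the current step, which also makes the old `and tensor_info` guard unnecessary. A stand-alone restatement of the new logic (stand-in type value; the previous-value lookup is stubbed as a parameter):

PARAMETER = 'Parameter'  # stand-in for NodeTypeEnum.PARAMETER.value

def update_has_prev_step_field(tensor_info, node_type, has_prev_value):
    """Mirror the reworked logic: only parameters with a current value get flagged."""
    flag = None
    cur_tensor_value = bool(tensor_info and tensor_info.get('value') is not None)
    if node_type == PARAMETER and cur_tensor_value:
        flag = has_prev_value  # stands in for _get_prev_tensor_value_status()
    if flag:
        tensor_info['has_prev_step'] = True
    return flag

info = {'value': [1.0]}
print(update_has_prev_step_field(info, PARAMETER, has_prev_value=True))  # True
print(info['has_prev_step'])                                             # True
print(update_has_prev_step_field({}, PARAMETER, has_prev_value=True))    # None: no current value
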