From 35e3bee5aade6e5b80790a21032c086dd99bdefd Mon Sep 17 00:00:00 2001 From: yelihua Date: Sat, 19 Sep 2020 10:55:44 +0800 Subject: [PATCH] fix the bug for update watchpoint --- mindinsight/debugger/common/utils.py | 2 +- mindinsight/debugger/debugger_server.py | 26 +++++-------------- .../debugger/stream_cache/debugger_graph.py | 16 ++++-------- .../debugger/stream_cache/watchpoint.py | 26 ++++++++++++------- .../debugger/stream_handler/graph_handler.py | 14 +++++++--- .../debugger/stream_handler/tensor_handler.py | 5 ++-- 6 files changed, 43 insertions(+), 46 deletions(-) diff --git a/mindinsight/debugger/common/utils.py b/mindinsight/debugger/common/utils.py index df776f4f..fff96e2a 100644 --- a/mindinsight/debugger/common/utils.py +++ b/mindinsight/debugger/common/utils.py @@ -18,8 +18,8 @@ from collections import namedtuple import numpy as np +from mindinsight.datavisual.data_transform.graph import NodeTypeEnum from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply -from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum # translate the MindSpore type to numpy type. NUMPY_TYPE_MAP = { diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py index 77066dc8..a5a00bb7 100644 --- a/mindinsight/debugger/debugger_server.py +++ b/mindinsight/debugger/debugger_server.py @@ -536,13 +536,11 @@ class DebuggerServer: node_infos = [] for node_name in node_names: node_type = graph_stream.get_node_type(node_name) - # optimizer later if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value: - sub_nodes = graph_stream.get_nodes(node_name) + sub_nodes = graph_stream.get_nodes_by_scope(node_name) sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) for node in sub_nodes] node_infos.extend(sub_infos) - continue full_name = graph_stream.get_full_name(node_name) node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type)) return node_infos @@ -615,17 +613,6 @@ class DebuggerServer: return {'metadata': {'state': current_state}} - def _validate_node_type(self, node_name): - """Check the node type in node control.""" - if not node_name: - return - node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name) - unsupported_types = [item.value for item in list(NodeTypeEnum)] - if node_type in unsupported_types: - log.error("Invalid node type. %s", node_name) - raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for " - "continue to command.") - def _construct_run_event(self, params): """ Construct run cmd from input control params. @@ -639,7 +626,7 @@ class DebuggerServer: - steps (int): Specify the steps that training should run. Used when `level` is `step`. - - full_name (str): Specify the name of the node. Used when `level` is `node`. + - name (str): Specify the name of the node. Used when `level` is `node`. Returns: EventReply, control event with run command. @@ -652,10 +639,11 @@ class DebuggerServer: steps = 1 run_cmd = RunCMD(run_level='step', run_steps=steps) elif level == 'node': - self._validate_node_type(params.get('name')) - name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name( - params['name']) - if not name: + name = params.get('name') + if name: + self._validate_leaf_name(name) + name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name) + else: name = '' run_cmd = RunCMD(run_level='node', node_name=name) else: diff --git a/mindinsight/debugger/stream_cache/debugger_graph.py b/mindinsight/debugger/stream_cache/debugger_graph.py index b290f1dd..af2abb74 100644 --- a/mindinsight/debugger/stream_cache/debugger_graph.py +++ b/mindinsight/debugger/stream_cache/debugger_graph.py @@ -16,7 +16,6 @@ from collections import deque from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph -from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum from mindinsight.debugger.common.exceptions.exceptions import \ DebuggerNodeNotInGraphError, DebuggerParamValueError from mindinsight.debugger.common.log import logger as log @@ -90,14 +89,14 @@ class DebuggerGraph(MSGraph): def search_nodes_by_pattern(self, pattern): """ - Search node names by a given pattern. + Search node by a given pattern. Args: pattern (Union[str, None]): The pattern of the node to search, if None, return all node names. Returns: - list[(str, str)], a list of tuple (node name, node type). + list[Node], a list of node. """ if pattern is not None: pattern = pattern.lower() @@ -106,7 +105,7 @@ class DebuggerGraph(MSGraph): if pattern in name.lower() ] else: - searched_nodes = [node for name, node in self._leaf_nodes.items()] + searched_nodes = [node for _, node in self._leaf_nodes.items()] return searched_nodes def _build_node_tree(self, node_name, node_type): @@ -147,13 +146,8 @@ class DebuggerGraph(MSGraph): if node_name and not self.exist_node(name=node_name): raise DebuggerNodeNotInGraphError(node_name=node_name) - node = self._leaf_nodes.get(node_name) - if node is not None: - node_type = node.type - else: - node_type = NodeTypeEnum.NAME_SCOPE.value - - return node_type + node = self._normal_node_map.get(node_name) + return node.type def get_tensor_history(self, node_name, depth=0): """ diff --git a/mindinsight/debugger/stream_cache/watchpoint.py b/mindinsight/debugger/stream_cache/watchpoint.py index fb267e56..206d78fb 100644 --- a/mindinsight/debugger/stream_cache/watchpoint.py +++ b/mindinsight/debugger/stream_cache/watchpoint.py @@ -71,17 +71,21 @@ class WatchNodeTree: """The property of watch status about current node.""" return self._watch_status - def enable_watch_status(self): - """The property of watch status about current node.""" - self._watch_status = WatchNodeTree.TOTAL_WATCH + def update_metadata(self, node_type, full_name, watch_status): + """Update the metadata for watched node.""" + self._full_name = full_name + self._node_type = self._translate_node_type(node_type) + self._watch_status = watch_status @staticmethod def _translate_node_type(node_type): """Translate node type to watch node type.""" - if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value or \ - node_type == NodeTypeEnum.AGGREGATION_SCOPE.value: - return 'scope' - return 'leaf' + flag = node_type + if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value: + flag = 'scope' + elif node_type != NodeTypeEnum.AGGREGATION_SCOPE.value: + flag = 'leaf' + return flag def get(self, sub_name): """Get sub node.""" @@ -104,10 +108,11 @@ class WatchNodeTree: log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name) scope_names = node_name.split('/', 1) if len(scope_names) == 1: - if not self.get(node_name): + target_node = self.get(node_name) + if not target_node: self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH) else: - self.get(node_name).enable_watch_status() + target_node.update_metadata(node_type, full_name, WatchNodeTree.TOTAL_WATCH) return scope_name, sub_names = scope_names @@ -232,7 +237,8 @@ class Watchpoint: cur_watch_node (WatchNodeTree): The current watch node. watch_node_list (list[WatchNodeTree]): The list of total watched node. """ - if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH: + if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH and \ + cur_watch_node.node_type != NodeTypeEnum.AGGREGATION_SCOPE.value: watch_node_list.append(cur_watch_node) return for _, watch_node in cur_watch_node.get_children(): diff --git a/mindinsight/debugger/stream_handler/graph_handler.py b/mindinsight/debugger/stream_handler/graph_handler.py index 8a54480e..5e716805 100644 --- a/mindinsight/debugger/stream_handler/graph_handler.py +++ b/mindinsight/debugger/stream_handler/graph_handler.py @@ -143,9 +143,17 @@ class GraphHandler(StreamHandlerBase): return {'nodes': nodes} - def get_node_names(self, pattern=None): - """Get graph nodes according to pattern.""" - return self._graph.search_nodes_by_pattern(pattern) + def get_nodes_by_scope(self, scope_name): + """ + Get node by a given scope name. + + Args: + scope_name (str): The name of scope. + + Returns: + list[Node], a list of node. + """ + return self._graph.search_nodes_by_pattern(scope_name) def get_searched_node_list(self): """Get searched node list.""" diff --git a/mindinsight/debugger/stream_handler/tensor_handler.py b/mindinsight/debugger/stream_handler/tensor_handler.py index 4666d9c9..4bbfc814 100644 --- a/mindinsight/debugger/stream_handler/tensor_handler.py +++ b/mindinsight/debugger/stream_handler/tensor_handler.py @@ -226,9 +226,10 @@ class TensorHandler(StreamHandlerBase): def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): """Update has_prev_step field in tensor info.""" flag = None - if node_type == NodeTypeEnum.PARAMETER.value: + cur_tensor_value = bool(tensor_info and tensor_info.get('value') is not None) + if node_type == NodeTypeEnum.PARAMETER.value and cur_tensor_value: flag = self._get_prev_tensor_value_status(tensor_name) - if flag and tensor_info: + if flag: tensor_info['has_prev_step'] = True return flag