| @@ -18,8 +18,8 @@ from collections import namedtuple | |||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply | from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply | ||||
| from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum | |||||
| # translate the MindSpore type to numpy type. | # translate the MindSpore type to numpy type. | ||||
| NUMPY_TYPE_MAP = { | NUMPY_TYPE_MAP = { | ||||
| @@ -536,13 +536,11 @@ class DebuggerServer: | |||||
| node_infos = [] | node_infos = [] | ||||
| for node_name in node_names: | for node_name in node_names: | ||||
| node_type = graph_stream.get_node_type(node_name) | node_type = graph_stream.get_node_type(node_name) | ||||
| # optimizer later | |||||
| if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value: | if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value: | ||||
| sub_nodes = graph_stream.get_nodes(node_name) | |||||
| sub_nodes = graph_stream.get_nodes_by_scope(node_name) | |||||
| sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) | sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) | ||||
| for node in sub_nodes] | for node in sub_nodes] | ||||
| node_infos.extend(sub_infos) | node_infos.extend(sub_infos) | ||||
| continue | |||||
| full_name = graph_stream.get_full_name(node_name) | full_name = graph_stream.get_full_name(node_name) | ||||
| node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type)) | node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type)) | ||||
| return node_infos | return node_infos | ||||
| @@ -615,17 +613,6 @@ class DebuggerServer: | |||||
| return {'metadata': {'state': current_state}} | return {'metadata': {'state': current_state}} | ||||
| def _validate_node_type(self, node_name): | |||||
| """Check the node type in node control.""" | |||||
| if not node_name: | |||||
| return | |||||
| node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name) | |||||
| unsupported_types = [item.value for item in list(NodeTypeEnum)] | |||||
| if node_type in unsupported_types: | |||||
| log.error("Invalid node type. %s", node_name) | |||||
| raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for " | |||||
| "continue to command.") | |||||
| def _construct_run_event(self, params): | def _construct_run_event(self, params): | ||||
| """ | """ | ||||
| Construct run cmd from input control params. | Construct run cmd from input control params. | ||||
| @@ -639,7 +626,7 @@ class DebuggerServer: | |||||
| - steps (int): Specify the steps that training should run. | - steps (int): Specify the steps that training should run. | ||||
| Used when `level` is `step`. | Used when `level` is `step`. | ||||
| - full_name (str): Specify the name of the node. Used when `level` is `node`. | |||||
| - name (str): Specify the name of the node. Used when `level` is `node`. | |||||
| Returns: | Returns: | ||||
| EventReply, control event with run command. | EventReply, control event with run command. | ||||
| @@ -652,10 +639,11 @@ class DebuggerServer: | |||||
| steps = 1 | steps = 1 | ||||
| run_cmd = RunCMD(run_level='step', run_steps=steps) | run_cmd = RunCMD(run_level='step', run_steps=steps) | ||||
| elif level == 'node': | elif level == 'node': | ||||
| self._validate_node_type(params.get('name')) | |||||
| name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name( | |||||
| params['name']) | |||||
| if not name: | |||||
| name = params.get('name') | |||||
| if name: | |||||
| self._validate_leaf_name(name) | |||||
| name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name) | |||||
| else: | |||||
| name = '' | name = '' | ||||
| run_cmd = RunCMD(run_level='node', node_name=name) | run_cmd = RunCMD(run_level='node', node_name=name) | ||||
| else: | else: | ||||
| @@ -16,7 +16,6 @@ | |||||
| from collections import deque | from collections import deque | ||||
| from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph | from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph | ||||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||||
| from mindinsight.debugger.common.exceptions.exceptions import \ | from mindinsight.debugger.common.exceptions.exceptions import \ | ||||
| DebuggerNodeNotInGraphError, DebuggerParamValueError | DebuggerNodeNotInGraphError, DebuggerParamValueError | ||||
| from mindinsight.debugger.common.log import logger as log | from mindinsight.debugger.common.log import logger as log | ||||
| @@ -90,14 +89,14 @@ class DebuggerGraph(MSGraph): | |||||
| def search_nodes_by_pattern(self, pattern): | def search_nodes_by_pattern(self, pattern): | ||||
| """ | """ | ||||
| Search node names by a given pattern. | |||||
| Search node by a given pattern. | |||||
| Args: | Args: | ||||
| pattern (Union[str, None]): The pattern of the node to search, | pattern (Union[str, None]): The pattern of the node to search, | ||||
| if None, return all node names. | if None, return all node names. | ||||
| Returns: | Returns: | ||||
| list[(str, str)], a list of tuple (node name, node type). | |||||
| list[Node], a list of node. | |||||
| """ | """ | ||||
| if pattern is not None: | if pattern is not None: | ||||
| pattern = pattern.lower() | pattern = pattern.lower() | ||||
| @@ -106,7 +105,7 @@ class DebuggerGraph(MSGraph): | |||||
| if pattern in name.lower() | if pattern in name.lower() | ||||
| ] | ] | ||||
| else: | else: | ||||
| searched_nodes = [node for name, node in self._leaf_nodes.items()] | |||||
| searched_nodes = [node for _, node in self._leaf_nodes.items()] | |||||
| return searched_nodes | return searched_nodes | ||||
| def _build_node_tree(self, node_name, node_type): | def _build_node_tree(self, node_name, node_type): | ||||
| @@ -147,13 +146,8 @@ class DebuggerGraph(MSGraph): | |||||
| if node_name and not self.exist_node(name=node_name): | if node_name and not self.exist_node(name=node_name): | ||||
| raise DebuggerNodeNotInGraphError(node_name=node_name) | raise DebuggerNodeNotInGraphError(node_name=node_name) | ||||
| node = self._leaf_nodes.get(node_name) | |||||
| if node is not None: | |||||
| node_type = node.type | |||||
| else: | |||||
| node_type = NodeTypeEnum.NAME_SCOPE.value | |||||
| return node_type | |||||
| node = self._normal_node_map.get(node_name) | |||||
| return node.type | |||||
| def get_tensor_history(self, node_name, depth=0): | def get_tensor_history(self, node_name, depth=0): | ||||
| """ | """ | ||||
| @@ -71,17 +71,21 @@ class WatchNodeTree: | |||||
| """The property of watch status about current node.""" | """The property of watch status about current node.""" | ||||
| return self._watch_status | return self._watch_status | ||||
| def enable_watch_status(self): | |||||
| """The property of watch status about current node.""" | |||||
| self._watch_status = WatchNodeTree.TOTAL_WATCH | |||||
| def update_metadata(self, node_type, full_name, watch_status): | |||||
| """Update the metadata for watched node.""" | |||||
| self._full_name = full_name | |||||
| self._node_type = self._translate_node_type(node_type) | |||||
| self._watch_status = watch_status | |||||
| @staticmethod | @staticmethod | ||||
| def _translate_node_type(node_type): | def _translate_node_type(node_type): | ||||
| """Translate node type to watch node type.""" | """Translate node type to watch node type.""" | ||||
| if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value or \ | |||||
| node_type == NodeTypeEnum.AGGREGATION_SCOPE.value: | |||||
| return 'scope' | |||||
| return 'leaf' | |||||
| flag = node_type | |||||
| if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value: | |||||
| flag = 'scope' | |||||
| elif node_type != NodeTypeEnum.AGGREGATION_SCOPE.value: | |||||
| flag = 'leaf' | |||||
| return flag | |||||
| def get(self, sub_name): | def get(self, sub_name): | ||||
| """Get sub node.""" | """Get sub node.""" | ||||
| @@ -104,10 +108,11 @@ class WatchNodeTree: | |||||
| log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name) | log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name) | ||||
| scope_names = node_name.split('/', 1) | scope_names = node_name.split('/', 1) | ||||
| if len(scope_names) == 1: | if len(scope_names) == 1: | ||||
| if not self.get(node_name): | |||||
| target_node = self.get(node_name) | |||||
| if not target_node: | |||||
| self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH) | self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH) | ||||
| else: | else: | ||||
| self.get(node_name).enable_watch_status() | |||||
| target_node.update_metadata(node_type, full_name, WatchNodeTree.TOTAL_WATCH) | |||||
| return | return | ||||
| scope_name, sub_names = scope_names | scope_name, sub_names = scope_names | ||||
| @@ -232,7 +237,8 @@ class Watchpoint: | |||||
| cur_watch_node (WatchNodeTree): The current watch node. | cur_watch_node (WatchNodeTree): The current watch node. | ||||
| watch_node_list (list[WatchNodeTree]): The list of total watched node. | watch_node_list (list[WatchNodeTree]): The list of total watched node. | ||||
| """ | """ | ||||
| if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH: | |||||
| if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH and \ | |||||
| cur_watch_node.node_type != NodeTypeEnum.AGGREGATION_SCOPE.value: | |||||
| watch_node_list.append(cur_watch_node) | watch_node_list.append(cur_watch_node) | ||||
| return | return | ||||
| for _, watch_node in cur_watch_node.get_children(): | for _, watch_node in cur_watch_node.get_children(): | ||||
| @@ -143,9 +143,17 @@ class GraphHandler(StreamHandlerBase): | |||||
| return {'nodes': nodes} | return {'nodes': nodes} | ||||
| def get_node_names(self, pattern=None): | |||||
| """Get graph nodes according to pattern.""" | |||||
| return self._graph.search_nodes_by_pattern(pattern) | |||||
| def get_nodes_by_scope(self, scope_name): | |||||
| """ | |||||
| Get node by a given scope name. | |||||
| Args: | |||||
| scope_name (str): The name of scope. | |||||
| Returns: | |||||
| list[Node], a list of node. | |||||
| """ | |||||
| return self._graph.search_nodes_by_pattern(scope_name) | |||||
| def get_searched_node_list(self): | def get_searched_node_list(self): | ||||
| """Get searched node list.""" | """Get searched node list.""" | ||||
| @@ -226,9 +226,10 @@ class TensorHandler(StreamHandlerBase): | |||||
| def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): | def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type): | ||||
| """Update has_prev_step field in tensor info.""" | """Update has_prev_step field in tensor info.""" | ||||
| flag = None | flag = None | ||||
| if node_type == NodeTypeEnum.PARAMETER.value: | |||||
| cur_tensor_value = bool(tensor_info and tensor_info.get('value') is not None) | |||||
| if node_type == NodeTypeEnum.PARAMETER.value and cur_tensor_value: | |||||
| flag = self._get_prev_tensor_value_status(tensor_name) | flag = self._get_prev_tensor_value_status(tensor_name) | ||||
| if flag and tensor_info: | |||||
| if flag: | |||||
| tensor_info['has_prev_step'] = True | tensor_info['has_prev_step'] = True | ||||
| return flag | return flag | ||||