Browse Source

fix the bug for update watchpoint

tags/v1.0.0
yelihua 5 years ago
parent
commit
35e3bee5aa
6 changed files with 43 additions and 46 deletions
  1. +1
    -1
      mindinsight/debugger/common/utils.py
  2. +7
    -19
      mindinsight/debugger/debugger_server.py
  3. +5
    -11
      mindinsight/debugger/stream_cache/debugger_graph.py
  4. +16
    -10
      mindinsight/debugger/stream_cache/watchpoint.py
  5. +11
    -3
      mindinsight/debugger/stream_handler/graph_handler.py
  6. +3
    -2
      mindinsight/debugger/stream_handler/tensor_handler.py

+ 1
- 1
mindinsight/debugger/common/utils.py View File

@@ -18,8 +18,8 @@ from collections import namedtuple

import numpy as np

from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum

# translate the MindSpore type to numpy type.
NUMPY_TYPE_MAP = {


+ 7
- 19
mindinsight/debugger/debugger_server.py View File

@@ -536,13 +536,11 @@ class DebuggerServer:
node_infos = []
for node_name in node_names:
node_type = graph_stream.get_node_type(node_name)
# optimizer later
if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
sub_nodes = graph_stream.get_nodes(node_name)
sub_nodes = graph_stream.get_nodes_by_scope(node_name)
sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
for node in sub_nodes]
node_infos.extend(sub_infos)
continue
full_name = graph_stream.get_full_name(node_name)
node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type))
return node_infos
@@ -615,17 +613,6 @@ class DebuggerServer:

return {'metadata': {'state': current_state}}

def _validate_node_type(self, node_name):
"""Check the node type in node control."""
if not node_name:
return
node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
unsupported_types = [item.value for item in list(NodeTypeEnum)]
if node_type in unsupported_types:
log.error("Invalid node type. %s", node_name)
raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for "
"continue to command.")

def _construct_run_event(self, params):
"""
Construct run cmd from input control params.
@@ -639,7 +626,7 @@ class DebuggerServer:
- steps (int): Specify the steps that training should run.
Used when `level` is `step`.

- full_name (str): Specify the name of the node. Used when `level` is `node`.
- name (str): Specify the name of the node. Used when `level` is `node`.

Returns:
EventReply, control event with run command.
@@ -652,10 +639,11 @@ class DebuggerServer:
steps = 1
run_cmd = RunCMD(run_level='step', run_steps=steps)
elif level == 'node':
self._validate_node_type(params.get('name'))
name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(
params['name'])
if not name:
name = params.get('name')
if name:
self._validate_leaf_name(name)
name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(name)
else:
name = ''
run_cmd = RunCMD(run_level='node', node_name=name)
else:


+ 5
- 11
mindinsight/debugger/stream_cache/debugger_graph.py View File

@@ -16,7 +16,6 @@
from collections import deque

from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph
from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
from mindinsight.debugger.common.exceptions.exceptions import \
DebuggerNodeNotInGraphError, DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
@@ -90,14 +89,14 @@ class DebuggerGraph(MSGraph):

def search_nodes_by_pattern(self, pattern):
"""
Search node names by a given pattern.
Search node by a given pattern.

Args:
pattern (Union[str, None]): The pattern of the node to search,
if None, return all node names.

Returns:
list[(str, str)], a list of tuple (node name, node type).
list[Node], a list of node.
"""
if pattern is not None:
pattern = pattern.lower()
@@ -106,7 +105,7 @@ class DebuggerGraph(MSGraph):
if pattern in name.lower()
]
else:
searched_nodes = [node for name, node in self._leaf_nodes.items()]
searched_nodes = [node for _, node in self._leaf_nodes.items()]
return searched_nodes

def _build_node_tree(self, node_name, node_type):
@@ -147,13 +146,8 @@ class DebuggerGraph(MSGraph):
if node_name and not self.exist_node(name=node_name):
raise DebuggerNodeNotInGraphError(node_name=node_name)

node = self._leaf_nodes.get(node_name)
if node is not None:
node_type = node.type
else:
node_type = NodeTypeEnum.NAME_SCOPE.value

return node_type
node = self._normal_node_map.get(node_name)
return node.type

def get_tensor_history(self, node_name, depth=0):
"""


+ 16
- 10
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -71,17 +71,21 @@ class WatchNodeTree:
"""The property of watch status about current node."""
return self._watch_status

def enable_watch_status(self):
"""The property of watch status about current node."""
self._watch_status = WatchNodeTree.TOTAL_WATCH
def update_metadata(self, node_type, full_name, watch_status):
"""Update the metadata for watched node."""
self._full_name = full_name
self._node_type = self._translate_node_type(node_type)
self._watch_status = watch_status

@staticmethod
def _translate_node_type(node_type):
"""Translate node type to watch node type."""
if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value or \
node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
return 'scope'
return 'leaf'
flag = node_type
if not node_type or node_type == NodeTypeEnum.NAME_SCOPE.value:
flag = 'scope'
elif node_type != NodeTypeEnum.AGGREGATION_SCOPE.value:
flag = 'leaf'
return flag

def get(self, sub_name):
"""Get sub node."""
@@ -104,10 +108,11 @@ class WatchNodeTree:
log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
scope_names = node_name.split('/', 1)
if len(scope_names) == 1:
if not self.get(node_name):
target_node = self.get(node_name)
if not target_node:
self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
else:
self.get(node_name).enable_watch_status()
target_node.update_metadata(node_type, full_name, WatchNodeTree.TOTAL_WATCH)
return

scope_name, sub_names = scope_names
@@ -232,7 +237,8 @@ class Watchpoint:
cur_watch_node (WatchNodeTree): The current watch node.
watch_node_list (list[WatchNodeTree]): The list of total watched node.
"""
if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH:
if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH and \
cur_watch_node.node_type != NodeTypeEnum.AGGREGATION_SCOPE.value:
watch_node_list.append(cur_watch_node)
return
for _, watch_node in cur_watch_node.get_children():


+ 11
- 3
mindinsight/debugger/stream_handler/graph_handler.py View File

@@ -143,9 +143,17 @@ class GraphHandler(StreamHandlerBase):

return {'nodes': nodes}

def get_node_names(self, pattern=None):
"""Get graph nodes according to pattern."""
return self._graph.search_nodes_by_pattern(pattern)
def get_nodes_by_scope(self, scope_name):
"""
Get node by a given scope name.

Args:
scope_name (str): The name of scope.

Returns:
list[Node], a list of node.
"""
return self._graph.search_nodes_by_pattern(scope_name)

def get_searched_node_list(self):
"""Get searched node list."""


+ 3
- 2
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -226,9 +226,10 @@ class TensorHandler(StreamHandlerBase):
def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
"""Update has_prev_step field in tensor info."""
flag = None
if node_type == NodeTypeEnum.PARAMETER.value:
cur_tensor_value = bool(tensor_info and tensor_info.get('value') is not None)
if node_type == NodeTypeEnum.PARAMETER.value and cur_tensor_value:
flag = self._get_prev_tensor_value_status(tensor_name)
if flag and tensor_info:
if flag:
tensor_info['has_prev_step'] = True
return flag



Loading…
Cancel
Save