Browse Source

fix the bug of watchpoint hit

tags/v1.0.0
yelihua 5 years ago
parent
commit
1b254d5b31
5 changed files with 30 additions and 18 deletions
  1. +5
    -1
      mindinsight/debugger/debugger_grpc_server.py
  2. +4
    -3
      mindinsight/debugger/debugger_server.py
  3. +0
    -5
      mindinsight/debugger/stream_cache/watchpoint.py
  4. +1
    -1
      mindinsight/debugger/stream_handler/tensor_handler.py
  5. +20
    -8
      mindinsight/debugger/stream_handler/watchpoint_handler.py

+ 5
- 1
mindinsight/debugger/debugger_grpc_server.py View File

@@ -52,6 +52,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._status = None
self._continue_steps = None
self._received_view_cmd = None
self._received_hit = None
self.init()

def init(self):
@@ -60,6 +61,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
self._status = ServerStatus.PENDING
self._continue_steps = 0
self._received_view_cmd = {}
self._received_hit = False
self._cache_store.clean()

@debugger_wrap
@@ -152,8 +154,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
def _send_watchpoint_hit_flag(self):
"""Send Watchpoint hit flag."""
watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
if watchpoint_hit_stream.empty:
if watchpoint_hit_stream.empty or not self._received_hit:
return
self._received_hit = False
watchpoint_hits_info = watchpoint_hit_stream.get()
self._cache_store.put_data(watchpoint_hits_info)
log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.")
@@ -302,6 +305,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
"""Send watchpoint hits info DebuggerCache."""
log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
self._continue_steps = 0
self._received_hit = True
watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)


+ 4
- 3
mindinsight/debugger/debugger_server.py View File

@@ -108,9 +108,10 @@ class DebuggerServer:
log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id)
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.validate_watchpoint_id(watch_point_id)
graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name)
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
graph = graph_stream.search_nodes(name)
# add watched label to graph
watchpoint_stream.set_watch_nodes(graph, watch_point_id)
watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id)
return graph

def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
@@ -247,7 +248,7 @@ class DebuggerServer:
reply = graph_stream.get(filter_condition)
graph = reply.get('graph')
# add watched label to graph
watchpoint_stream.set_watch_nodes(graph, watch_point_id)
watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id)
return reply

def retrieve_tensor_history(self, node_name):


+ 0
- 5
mindinsight/debugger/stream_cache/watchpoint.py View File

@@ -289,11 +289,6 @@ class WatchpointHit:
tensor_name = ':'.join([self._full_name, self._slot])
return tensor_name

@property
def tensor_name(self):
"""The property of tensor_name."""
return ':'.join([self._node_name, self._slot])

@property
def watchpoint(self):
"""The property of watchpoint."""


+ 1
- 1
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -99,7 +99,7 @@ class TensorHandler(StreamHandlerBase):

old_tensor = cache_tensor.get(step)
if old_tensor and not self.is_value_diff(old_tensor.value, tensor.value):
log.debug("Tensor %s of step %s has no change. Ignore it.")
log.debug("Tensor %s of step %s has no change. Ignore it.", tensor.name, step)
return False
cache_tensor[step] = tensor
log.debug("Put updated tensor value for %s of step %s.", tensor.name, step)


+ 20
- 8
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -88,27 +88,28 @@ class WatchpointHandler(StreamHandlerBase):

return {'watch_points': reply}

def set_watch_nodes(self, graph, watch_point_id):
def set_watch_nodes(self, graph, graph_stream, watch_point_id):
"""
set watch nodes for graph.

Args:
graph (dict): The graph with list of nodes.
graph_stream (GraphHandler): The graph handler.
watch_point_id (int): The id of watchpoint.
"""
if not (watch_point_id and graph):
return
log.debug("add watch flags")
watchpoint = self._watchpoints.get(watch_point_id)
self._set_watch_status_recursively(graph, watchpoint)
self._set_watch_status_recursively(graph, graph_stream, watchpoint)

def _set_watch_status_recursively(self, graph, watchpoint):
def _set_watch_status_recursively(self, graph, graph_stream, watchpoint):
"""Set watch status to graph."""
if not isinstance(graph, dict):
log.warning("The graph is not dict.")
return
if graph.get('children'):
self._set_watch_status_recursively(graph.get('children'), watchpoint)
self._set_watch_status_recursively(graph.get('children'), graph_stream, watchpoint)

for node in graph.get('nodes', []):
if not isinstance(node, dict):
@@ -117,10 +118,11 @@ class WatchpointHandler(StreamHandlerBase):
node_name = node.get('name')
if not node_name:
continue
flag = watchpoint.get_node_status(node_name, node.get('type'), node.get('full_name'))
full_name = graph_stream.get_full_name(node_name)
flag = watchpoint.get_node_status(node_name, node.get('type'), full_name)
node['watched'] = flag
if node.get('nodes'):
self._set_watch_status_recursively(node, watchpoint)
self._set_watch_status_recursively(node, graph_stream, watchpoint)

def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
"""
@@ -220,6 +222,8 @@ class WatchpointHitHandler(StreamHandlerBase):
- tensor_proto (TensorProto): The message about hit tensor.

- watchpoint (Watchpoint): The Watchpoint that a node hit.

- node_name (str): The UI node name.
"""
watchpoint_hit = WatchpointHit(
tensor_proto=value.get('tensor_proto'),
@@ -268,14 +272,22 @@ class WatchpointHitHandler(StreamHandlerBase):
return {'watch_point_hits': watch_point_hits}

def _is_tensor_hit(self, tensor_name):
"""Check if the tensor is record in hit cache."""
"""
Check if the tensor is record in hit cache.

Args:
tensor_name (str): The name of full tensor name.

Returns:
bool, if the tensor is hit.
"""
node_name = tensor_name.split(':')[0]
watchpoint_hits = self.get(node_name)
if watchpoint_hits is None:
return False

for watchpoint_hit in watchpoint_hits:
if tensor_name == watchpoint_hit.tensor_name:
if tensor_name == watchpoint_hit.tensor_full_name:
return True

return False


Loading…
Cancel
Save