| @@ -52,6 +52,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._status = None | |||
| self._continue_steps = None | |||
| self._received_view_cmd = None | |||
| self._received_hit = None | |||
| self.init() | |||
| def init(self): | |||
| @@ -60,6 +61,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| self._status = ServerStatus.PENDING | |||
| self._continue_steps = 0 | |||
| self._received_view_cmd = {} | |||
| self._received_hit = False | |||
| self._cache_store.clean() | |||
| @debugger_wrap | |||
| @@ -152,8 +154,9 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def _send_watchpoint_hit_flag(self): | |||
| """Send Watchpoint hit flag.""" | |||
| watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| if watchpoint_hit_stream.empty: | |||
| if watchpoint_hit_stream.empty or not self._received_hit: | |||
| return | |||
| self._received_hit = False | |||
| watchpoint_hits_info = watchpoint_hit_stream.get() | |||
| self._cache_store.put_data(watchpoint_hits_info) | |||
| log.debug("Send the watchpoint hits to DataQueue.\nSend the reply.") | |||
| @@ -302,6 +305,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| """Send watchpoint hits info DebuggerCache.""" | |||
| log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps) | |||
| self._continue_steps = 0 | |||
| self._received_hit = True | |||
| watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) | |||
| watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) | |||
| @@ -108,9 +108,10 @@ class DebuggerServer: | |||
| log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id) | |||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| watchpoint_stream.validate_watchpoint_id(watch_point_id) | |||
| graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name) | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph = graph_stream.search_nodes(name) | |||
| # add watched label to graph | |||
| watchpoint_stream.set_watch_nodes(graph, watch_point_id) | |||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id) | |||
| return graph | |||
| def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): | |||
| @@ -247,7 +248,7 @@ class DebuggerServer: | |||
| reply = graph_stream.get(filter_condition) | |||
| graph = reply.get('graph') | |||
| # add watched label to graph | |||
| watchpoint_stream.set_watch_nodes(graph, watch_point_id) | |||
| watchpoint_stream.set_watch_nodes(graph, graph_stream, watch_point_id) | |||
| return reply | |||
| def retrieve_tensor_history(self, node_name): | |||
| @@ -289,11 +289,6 @@ class WatchpointHit: | |||
| tensor_name = ':'.join([self._full_name, self._slot]) | |||
| return tensor_name | |||
| @property | |||
| def tensor_name(self): | |||
| """The property of tensor_name.""" | |||
| return ':'.join([self._node_name, self._slot]) | |||
| @property | |||
| def watchpoint(self): | |||
| """The property of watchpoint.""" | |||
| @@ -99,7 +99,7 @@ class TensorHandler(StreamHandlerBase): | |||
| old_tensor = cache_tensor.get(step) | |||
| if old_tensor and not self.is_value_diff(old_tensor.value, tensor.value): | |||
| log.debug("Tensor %s of step %s has no change. Ignore it.") | |||
| log.debug("Tensor %s of step %s has no change. Ignore it.", tensor.name, step) | |||
| return False | |||
| cache_tensor[step] = tensor | |||
| log.debug("Put updated tensor value for %s of step %s.", tensor.name, step) | |||
| @@ -88,27 +88,28 @@ class WatchpointHandler(StreamHandlerBase): | |||
| return {'watch_points': reply} | |||
| def set_watch_nodes(self, graph, watch_point_id): | |||
| def set_watch_nodes(self, graph, graph_stream, watch_point_id): | |||
| """ | |||
| set watch nodes for graph. | |||
| Args: | |||
| graph (dict): The graph with list of nodes. | |||
| graph_stream (GraphHandler): The graph handler. | |||
| watch_point_id (int): The id of watchpoint. | |||
| """ | |||
| if not (watch_point_id and graph): | |||
| return | |||
| log.debug("add watch flags") | |||
| watchpoint = self._watchpoints.get(watch_point_id) | |||
| self._set_watch_status_recursively(graph, watchpoint) | |||
| self._set_watch_status_recursively(graph, graph_stream, watchpoint) | |||
| def _set_watch_status_recursively(self, graph, watchpoint): | |||
| def _set_watch_status_recursively(self, graph, graph_stream, watchpoint): | |||
| """Set watch status to graph.""" | |||
| if not isinstance(graph, dict): | |||
| log.warning("The graph is not dict.") | |||
| return | |||
| if graph.get('children'): | |||
| self._set_watch_status_recursively(graph.get('children'), watchpoint) | |||
| self._set_watch_status_recursively(graph.get('children'), graph_stream, watchpoint) | |||
| for node in graph.get('nodes', []): | |||
| if not isinstance(node, dict): | |||
| @@ -117,10 +118,11 @@ class WatchpointHandler(StreamHandlerBase): | |||
| node_name = node.get('name') | |||
| if not node_name: | |||
| continue | |||
| flag = watchpoint.get_node_status(node_name, node.get('type'), node.get('full_name')) | |||
| full_name = graph_stream.get_full_name(node_name) | |||
| flag = watchpoint.get_node_status(node_name, node.get('type'), full_name) | |||
| node['watched'] = flag | |||
| if node.get('nodes'): | |||
| self._set_watch_status_recursively(node, watchpoint) | |||
| self._set_watch_status_recursively(node, graph_stream, watchpoint) | |||
| def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None): | |||
| """ | |||
| @@ -220,6 +222,8 @@ class WatchpointHitHandler(StreamHandlerBase): | |||
| - tensor_proto (TensorProto): The message about hit tensor. | |||
| - watchpoint (Watchpoint): The Watchpoint that a node hit. | |||
| - node_name (str): The UI node name. | |||
| """ | |||
| watchpoint_hit = WatchpointHit( | |||
| tensor_proto=value.get('tensor_proto'), | |||
| @@ -268,14 +272,22 @@ class WatchpointHitHandler(StreamHandlerBase): | |||
| return {'watch_point_hits': watch_point_hits} | |||
| def _is_tensor_hit(self, tensor_name): | |||
| """Check if the tensor is record in hit cache.""" | |||
| """ | |||
| Check if the tensor is record in hit cache. | |||
| Args: | |||
| tensor_name (str): The name of full tensor name. | |||
| Returns: | |||
| bool, if the tensor is hit. | |||
| """ | |||
| node_name = tensor_name.split(':')[0] | |||
| watchpoint_hits = self.get(node_name) | |||
| if watchpoint_hits is None: | |||
| return False | |||
| for watchpoint_hit in watchpoint_hits: | |||
| if tensor_name == watchpoint_hit.tensor_name: | |||
| if tensor_name == watchpoint_hit.tensor_full_name: | |||
| return True | |||
| return False | |||