diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_grpc_server.py index 5ac75ee1..d2ac92d2 100644 --- a/mindinsight/debugger/debugger_grpc_server.py +++ b/mindinsight/debugger/debugger_grpc_server.py @@ -99,7 +99,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): def _pre_process(self, request): """Pre-process before dealing with command.""" metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) - watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) is_new_step = metadata_stream.step < request.cur_step is_new_node = metadata_stream.full_name != request.cur_node # clean cache data at the beginning of new step or node has been changed. @@ -108,15 +107,12 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): if is_new_step: self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(request.cur_step) - watchpoint_stream.clean_temp_cached_names() # receive graph at the beginning of the training if self._status == ServerStatus.RECEIVE_GRAPH: self._send_graph_flag(metadata_stream) # receive new metadata if is_new_step or is_new_node: self._update_metadata(metadata_stream, request) - # save the full name of the node which MindSpore has stored the tensor. - watchpoint_stream.add_temp_cached_name(request.cur_node) self._send_received_tensor_tag() self._send_watchpoint_hit_flag() diff --git a/mindinsight/debugger/stream_cache/watchpoint.py b/mindinsight/debugger/stream_cache/watchpoint.py index 6001317c..4d0e57c6 100644 --- a/mindinsight/debugger/stream_cache/watchpoint.py +++ b/mindinsight/debugger/stream_cache/watchpoint.py @@ -13,13 +13,15 @@ # limitations under the License. 
# ============================================================================ """Define the watchpoint stream.""" -from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo + +import copy + from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError from mindinsight.debugger.common.log import LOGGER as log from mindinsight.debugger.common.utils import is_scope_type -from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition +from mindinsight.debugger.conditionmgr.common.utils import NodeBasicInfo from mindinsight.debugger.conditionmgr.condition import ConditionIdEnum - +from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition WATCHPOINT_CONDITION_MAPPING = { ConditionIdEnum.NAN.value: WatchCondition.Condition.nan, @@ -109,7 +111,7 @@ class WatchNodeTree: return self._children.get(sub_name) def get_children(self): - """Get all childrens.""" + """Get all children.""" for name_scope, sub_watch_node in self._children.items(): yield name_scope, sub_watch_node @@ -198,13 +200,17 @@ class Watchpoint: """The property of watch condition.""" return self._condition - def copy_nodes_from(self, other_watchpoint): + def copy_nodes_from(self, other_watchpoint, deep_copy=False): """ Copy nodes from other watchpoint. Args: other_watchpoint (Watchpoint): Other watchpoint. + deep_copy (bool): Whether using deepcopy. 
""" - self._watch_node = other_watchpoint.nodes + if deep_copy: + self._watch_node = copy.deepcopy(other_watchpoint.nodes) + else: + self._watch_node = other_watchpoint.nodes def add_nodes(self, nodes): """Add node into watchpoint.""" diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py index bcbfcd6e..3f747ecc 100644 --- a/mindinsight/debugger/stream_handler/watchpoint_handler.py +++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py @@ -33,19 +33,13 @@ class WatchpointHandler(StreamHandlerBase): self._created_watchpoints = [] # list of SetCMD of watchpoints to be deleted self._deleted_watchpoints = [] - # dict of of watchpoint to be updated + # dict of watchpoints to be updated self._updated_watchpoints = {} # the collection of watched node full names, which have been sent to MindSpore - self._all_watched_node_full_names = set() - # the collection of new watched node full names, which have not been sent to MindSpore - self._new_watched_node_full_names = set() - # record the temp stored nodes in MS, which could be set as watch node for recheck on GPU - # should be clean at the beginning of each step - self._temp_cached_node_full_names = set() self._latest_id = 0 self._cache_set_cmd = {} # whether the watchpoint list has been changed since last step - self.outdated = False + self._outdated = False def put(self, value): """ @@ -61,18 +55,9 @@ class WatchpointHandler(StreamHandlerBase): self._latest_id = new_id log.debug("Put watchpoint %d into cache.", new_id) - def clean_temp_cached_names(self): - """Clean temp cached node.""" - self._temp_cached_node_full_names.clear() - - def add_temp_cached_name(self, node_full_name): - """Add temp stored node in cache.""" - if node_full_name: - self._temp_cached_node_full_names.add(node_full_name) - def sync_set_cmd(self, set_cmds): """Clean temp watchpoints.""" - self._new_watched_node_full_names = set() + self._outdated = False 
self._created_watchpoints = [] self._deleted_watchpoints = [] self._updated_watchpoints = {} @@ -126,20 +111,14 @@ class WatchpointHandler(StreamHandlerBase): list[SetCMD], updated watchpoint to be sent to MindSpore. """ res = [] - new_watched_nodes = set() - self._all_watched_node_full_names.clear() for _, watchpoint in self._updated_watchpoints.items(): # construct set command with leaf nodes watch_nodes = watchpoint.get_watch_nodes() leaf_watch_nodes = self._expand_to_leaf_nodes(graph_stream, watch_nodes) res.append(watchpoint.get_pending_cmd(leaf_watch_nodes)) - # update all watched node names - watch_node_names = [watch_node.full_name for watch_node in [*watch_nodes, *leaf_watch_nodes]] - new_watched_nodes.update(watch_node_names) res.extend(self._deleted_watchpoints) for _, set_cmd in self._cache_set_cmd.items(): res.append(set_cmd) - self._all_watched_node_full_names = new_watched_nodes return res @staticmethod @@ -168,23 +147,14 @@ class WatchpointHandler(StreamHandlerBase): leaf_watch_nodes.append(node) return leaf_watch_nodes - def is_recheckable(self, backend=None): + def is_recheckable(self): """ Check if current status is able to recheck. - Args: - backend (str): The backend info. 'Ascend' or 'GPU'. Default: None. - Returns: bool, if enable to recheck. 
""" - enable_recheck = self.outdated - if backend == 'GPU' and enable_recheck: - # on GPU, disable to recheck if there are new watched node of which the tensor - # has not been stored on MindSpore - diff_set = self._new_watched_node_full_names - self._all_watched_node_full_names - enable_recheck = not diff_set or diff_set.issubset(self._temp_cached_node_full_names) - return enable_recheck + return self._outdated def set_watch_nodes(self, graph, graph_stream, watch_point_id, graph_name=None): """ @@ -274,12 +244,11 @@ class WatchpointHandler(StreamHandlerBase): watchpoint = Watchpoint(new_id, watch_condition) if watch_nodes: watchpoint.add_nodes(watch_nodes) - self._add_watch_node_in_cache(watch_nodes) elif watch_point_id: self.validate_watchpoint_id(watch_point_id) watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) self.put(watchpoint) - self.outdated = True + self._outdated = True return new_id def update_watchpoint(self, watch_point_id, watch_nodes, watched=False): @@ -296,12 +265,10 @@ class WatchpointHandler(StreamHandlerBase): watchpoint = self._watchpoints.get(watch_point_id) if watched: watchpoint.add_nodes(watch_nodes) - self._add_watch_node_in_cache(watch_nodes) else: watchpoint.remove_nodes(watch_nodes) - self._remove_watch_node_from_cache(watch_nodes) self._updated_watchpoints[watch_point_id] = watchpoint - self.outdated = True + self._outdated = True log.debug("Update watchpoint %d in cache.", watch_point_id) def delete_watchpoint(self, watch_point_id=None): @@ -319,7 +286,7 @@ class WatchpointHandler(StreamHandlerBase): watch_point_ids = [watch_point_id] for single_id in watch_point_ids: self._delete_single_watchpoint(single_id) - self.outdated = True + self._outdated = True def _delete_single_watchpoint(self, watch_point_id): """ @@ -350,27 +317,6 @@ class WatchpointHandler(StreamHandlerBase): log.error("Invalid watchpoint id: %d.", watch_point_id) raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id)) - def 
_add_watch_node_in_cache(self, watch_nodes): - """ - Add watch nodes in cache. - - Args: - watch_nodes (list[NodeBasicInfo]): The list of node basic info. - """ - node_full_names = [node.full_name for node in watch_nodes] - self._new_watched_node_full_names.update(node_full_names) - - def _remove_watch_node_from_cache(self, watch_nodes): - """ - Remove watch nodes from cache. - - Args: - watch_nodes (list[NodeBasicInfo]): The list of node basic info. - """ - for node in watch_nodes: - if node.full_name in self._new_watched_node_full_names: - self._new_watched_node_full_names.remove(node.full_name) - class WatchpointHitHandler(StreamHandlerBase): """Watchpoint hit handler.""" diff --git a/mindinsight/debugger/stream_operator/watchpoint_operator.py b/mindinsight/debugger/stream_operator/watchpoint_operator.py index 897baee0..81d20d6e 100644 --- a/mindinsight/debugger/stream_operator/watchpoint_operator.py +++ b/mindinsight/debugger/stream_operator/watchpoint_operator.py @@ -87,7 +87,7 @@ class WatchpointOperator: self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id')) log.info("Create watchpoint %d", watch_point_id) - metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend) + metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() res = metadata_stream.get(['state', 'enable_recheck']) res['id'] = watch_point_id return res @@ -140,7 +140,7 @@ class WatchpointOperator: search_pattern=params.get('search_pattern'), graph_name=params.get('graph_name')) watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode')) - metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend) + metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() log.info("Update watchpoint with id: %d", watch_point_id) return metadata_stream.get(['state', 'enable_recheck']) diff --git a/tests/st/func/debugger/test_restful_api.py 
b/tests/st/func/debugger/test_restful_api.py index 65a8bbc6..c3635565 100644 --- a/tests/st/func/debugger/test_restful_api.py +++ b/tests/st/func/debugger/test_restful_api.py @@ -434,7 +434,7 @@ class TestGPUDebugger: @pytest.mark.parametrize("url, body_data, enable_recheck", [ ('create_watchpoint', {'condition': {'id': 'inf', 'params': []}, - 'watch_nodes': ['Default']}, False), + 'watch_nodes': ['Default']}, True), ('create_watchpoint', {'condition': {'id': 'inf', 'params': []}, 'watch_nodes': ['Default/TransData-op99']}, True), @@ -443,7 +443,7 @@ class TestGPUDebugger: 'mode': 0}, True), ('update_watchpoint', {'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'], - 'mode': 1}, False), + 'mode': 1}, True), ('update_watchpoint', [{'watch_point_id': 1, 'watch_nodes': ['Default/optimizer-Momentum'], 'mode': 1},