| @@ -213,15 +213,9 @@ def create_watchpoint(): | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/create_watchpoint | |||
| """ | |||
| body = _read_post_request(request) | |||
| condition = body.get('condition') | |||
| graph_name = body.get('graph_name') | |||
| watch_nodes = body.get('watch_nodes') | |||
| watch_point_id = body.get('watch_point_id') | |||
| search_pattern = body.get('search_pattern') | |||
| reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, | |||
| condition, watch_nodes, watch_point_id, search_pattern, graph_name) | |||
| params = _read_post_request(request) | |||
| params['watch_condition'] = params.pop('condition', None) | |||
| reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, params) | |||
| return reply | |||
| @@ -239,14 +233,8 @@ def update_watchpoint(): | |||
| Examples: | |||
| >>> POST http://xxxx/v1/mindinsight/debugger/update_watchpoint | |||
| """ | |||
| body = _read_post_request(request) | |||
| watch_point_id = body.get('watch_point_id') | |||
| watch_nodes = body.get('watch_nodes') | |||
| graph_name = body.get('graph_name') | |||
| mode = body.get('mode') | |||
| pattern = body.get('search_pattern') | |||
| reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, watch_point_id, watch_nodes, mode, pattern, graph_name) | |||
| params = _read_post_request(request) | |||
| reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, params) | |||
| return reply | |||
| @@ -15,21 +15,20 @@ | |||
| """Implement the debugger server.""" | |||
| import signal | |||
| from concurrent import futures | |||
| from functools import wraps | |||
| from threading import Thread | |||
| import grpc | |||
| from mindinsight.conditionmgr.common.utils import NodeBasicInfo | |||
| from mindinsight.conditionmgr.condition import ConditionContext, ConditionIdEnum | |||
| from mindinsight.conditionmgr.condition import ConditionContext | |||
| from mindinsight.conditionmgr.conditionmgr import ConditionMgr | |||
| from mindinsight.conditionmgr.recommender import recommend_watchpoints | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||
| from mindinsight.datavisual.utils.tools import to_float | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ | |||
| DebuggerDeleteWatchPointError, DebuggerCompareTensorError, DebuggerTensorGraphError, \ | |||
| DebuggerTensorHitError | |||
| DebuggerParamTypeError, DebuggerCompareTensorError, DebuggerTensorGraphError, \ | |||
| DebuggerTensorHitError, MindInsightException | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import ServerStatus, \ | |||
| create_view_event_from_tensor_basic_info, Streams | |||
| @@ -38,9 +37,26 @@ from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||
| from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo | |||
| from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator | |||
| from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator | |||
| from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR | |||
| def try_except(func): | |||
| """Send latest metadata when catch exception.""" | |||
| @wraps(func) | |||
| def send_latest_metadata(self, *args, **kwargs): | |||
| try: | |||
| return func(self, *args, **kwargs) | |||
| except MindInsightException as err: | |||
| metadata = self.cache_store.get_stream_handler(Streams.METADATA).get() | |||
| self.cache_store.put_data(metadata) | |||
| log.info("Put latest metadata into data-queue.") | |||
| raise err | |||
| return send_latest_metadata | |||
| class DebuggerServer: | |||
| """The server manager of debugger.""" | |||
| @@ -488,131 +504,57 @@ class DebuggerServer: | |||
| return reply | |||
| def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None, search_pattern=None, | |||
| graph_name=None): | |||
| def create_watchpoint(self, params): | |||
| """ | |||
| Create watchpoint. | |||
| Args: | |||
| watch_condition (dict): The watch condition. The format is like: | |||
| { | |||
| "id": "tensor_too_large", | |||
| "params": [ | |||
| { | |||
| "name": "abs_mean_gt", | |||
| "disable": false, | |||
| "value": 1.1 | |||
| } | |||
| ] | |||
| } | |||
| - id (str): Id of condition. | |||
| - params (list[dict]): The list of param for this condition. | |||
| watch_nodes (list[str]): The list of node names. | |||
| watch_point_id (int): The id of watchpoint. | |||
| search_pattern (dict): The search pattern. Default: None. | |||
| graph_name (str): The relative graph_name of the watched node. Default: None. | |||
| params (dict): Params for create watchpoint. | |||
| - watch_condition (dict): The watch condition. The format is like: | |||
| { | |||
| "id": "tensor_too_large", | |||
| "params": [ | |||
| { | |||
| "name": "abs_mean_gt", | |||
| "disable": false, | |||
| "value": 1.1 | |||
| } | |||
| ] | |||
| } | |||
| - id (str): Id of condition. | |||
| - params (list[dict]): The list of param for this condition. | |||
| - watch_nodes (list[str]): The list of node names. | |||
| - watch_point_id (int): The id of watchpoint. | |||
| - search_pattern (dict): The search pattern. | |||
| - graph_name (str): The relative graph_name of the watched node. | |||
| Returns: | |||
| dict, the id of new watchpoint and metadata info. | |||
| """ | |||
| log.info("Received create watchpoint request. WatchCondition: %s", watch_condition) | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| log.error("Failed to create watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerCreateWatchPointError( | |||
| "Failed to create watchpoint as the MindSpore is not in waiting state.") | |||
| if metadata_stream.backend == 'GPU' and watch_condition.get('id') in ( | |||
| ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value, ConditionIdEnum.OPERATOR_OVERFLOW.value): | |||
| log.error("GPU doesn't support overflow watch condition.") | |||
| raise DebuggerParamValueError("GPU doesn't support overflow watch condition.") | |||
| if metadata_stream.backend == 'Ascend' and watch_condition.get('id') == ConditionIdEnum.NAN.value: | |||
| log.error("Ascend doesn't support nan watch condition.") | |||
| raise DebuggerParamValueError("Ascend doesn't support nan watch condition.") | |||
| watch_nodes = self._get_watch_node_with_basic_info( | |||
| node_names=watch_nodes, search_pattern=search_pattern, graph_name=graph_name) | |||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| watch_point_id = watchpoint_stream.create_watchpoint( | |||
| self.condition_mgr, watch_condition, watch_nodes, watch_point_id) | |||
| log.info("Create watchpoint %d", watch_point_id) | |||
| watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr) | |||
| return watchpoint_opt.create_watchpoint(params) | |||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend) | |||
| res = metadata_stream.get(['state', 'enable_recheck']) | |||
| res['id'] = watch_point_id | |||
| return res | |||
| def update_watchpoint(self, watch_point_id, watch_nodes, mode, search_pattern=None, graph_name=None): | |||
| def update_watchpoint(self, params): | |||
| """ | |||
| Update watchpoint. | |||
| Args: | |||
| watch_point_id (int): The id of watchpoint. | |||
| watch_nodes (list[str]): The list of node names. | |||
| mode (int): The update operator on nodes. 0 for remove nodes from watch nodes. | |||
| 1 for add nodes to watch nodes. | |||
| search_pattern (dict): The search pattern. Default: None. | |||
| graph_name (str): The relative graph_name of the watched node. Default: None. | |||
| params (dict): Params for update watchpoint. | |||
| Returns: | |||
| dict, the metadata info. | |||
| """ | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| log.error("Failed to update watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerUpdateWatchPointError( | |||
| "Failed to update watchpoint as the MindSpore is not in waiting state." | |||
| ) | |||
| # validate parameter | |||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| watchpoint_stream.validate_watchpoint_id(watch_point_id) | |||
| if not watch_nodes or not watch_point_id: | |||
| log.error("Invalid parameter for update watchpoint.") | |||
| raise DebuggerParamValueError("Invalid parameter for update watchpoint.") | |||
| # get node basic info for watch nodes | |||
| watch_nodes = self._get_watch_node_with_basic_info(watch_nodes, search_pattern, graph_name) | |||
| watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, mode) | |||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend) | |||
| log.info("Update watchpoint with id: %d", watch_point_id) | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None): | |||
| """ | |||
| Get watch node with basic info. | |||
| Args: | |||
| node_names (list[str]): A list of node names. | |||
| search_pattern (dict): Get watch node with search pattern. Default: None | |||
| graph_name (str): The relative graph_name of the watched node. Default: None. | |||
| - watch_point_id (int): The id of watchpoint. | |||
| - watch_nodes (list[str]): The list of node names. | |||
| - mode (int): The update operator on nodes. 0 for remove nodes from watch nodes. | |||
| 1 for add nodes to watch nodes. | |||
| - search_pattern (dict): The search pattern. | |||
| - graph_name (str): The relative graph_name of the watched node. | |||
| Returns: | |||
| list[NodeBasicInfo], a list of node basic infos. | |||
| dict, the metadata info. | |||
| """ | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| graph_name = graph_stream.validate_graph_name(graph_name) | |||
| if search_pattern is not None: | |||
| watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name) | |||
| else: | |||
| watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name) | |||
| return watch_nodes | |||
| def _get_watch_nodes_by_search(self, watch_nodes, search_pattern, graph_name): | |||
| """Get watched leaf nodes by search name.""" | |||
| watched_leaf_nodes = [] | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| new_pattern = {'graph_name': graph_name}.update(search_pattern) | |||
| for search_name in watch_nodes: | |||
| search_nodes = graph_stream.get_searched_node_list(new_pattern) | |||
| search_node_names = [ | |||
| NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) | |||
| for node in search_nodes | |||
| if node.name.startswith(search_name)] | |||
| watched_leaf_nodes.extend(search_node_names) | |||
| log.debug("Update nodes: %s", watched_leaf_nodes) | |||
| return watched_leaf_nodes | |||
| watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr) | |||
| return watchpoint_opt.update_watchpoint(params) | |||
| def delete_watchpoint(self, watch_point_id=None): | |||
| """ | |||
| @@ -625,39 +567,10 @@ class DebuggerServer: | |||
| Returns: | |||
| dict, the metadata info. | |||
| """ | |||
| metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| log.error("Failed to delete watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerDeleteWatchPointError( | |||
| "Failed to delete watchpoint as the MindSpore is not in waiting state." | |||
| ) | |||
| watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| watchpoint_stream.delete_watchpoint(watch_point_id) | |||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() | |||
| log.info("Delete watchpoint with id: %s", watch_point_id) | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def _get_node_basic_infos(self, node_names, graph_name=None): | |||
| """ | |||
| Get node info according to node names. | |||
| Args: | |||
| node_names (list[str]): A list of node names. | |||
| graph_name (str): The relative graph_name of the watched node. Default: None. | |||
| Returns: | |||
| list[NodeBasicInfo], a list of basic node infos. | |||
| """ | |||
| if not node_names: | |||
| return [] | |||
| graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) | |||
| node_infos = [] | |||
| for node_name in node_names: | |||
| node_info = graph_stream.get_node_basic_info(node_name, graph_name) | |||
| node_infos.append(node_info) | |||
| return node_infos | |||
| watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr) | |||
| return watchpoint_opt.delete_watchpoint(watch_point_id=watch_point_id) | |||
| @try_except | |||
| def control(self, params=None): | |||
| """ | |||
| Control the training process. | |||
| @@ -678,7 +591,7 @@ class DebuggerServer: | |||
| dict, the response. | |||
| """ | |||
| log.info("Receive control request: %s.", params) | |||
| mode = params.pop('mode', None) | |||
| mode = params.pop('mode', None) if params else None | |||
| training_controller = TrainingControlOperator(self.cache_store) | |||
| training_controller.validate_mode(mode) | |||
| return training_controller.control(mode, params) | |||
| @@ -717,6 +630,7 @@ class DebuggerServer: | |||
| return reply | |||
| @try_except | |||
| def recheck(self): | |||
| """ | |||
| Recheck all watchpoints. | |||
| @@ -44,6 +44,8 @@ class WatchpointHandler(StreamHandlerBase): | |||
| self._temp_cached_node_full_names = set() | |||
| self._latest_id = 0 | |||
| self._cache_set_cmd = {} | |||
| # whether the watchpoint list has been changed since last step | |||
| self.outdated = False | |||
| def put(self, value): | |||
| """ | |||
| @@ -176,7 +178,7 @@ class WatchpointHandler(StreamHandlerBase): | |||
| Returns: | |||
| bool, if enable to recheck. | |||
| """ | |||
| enable_recheck = bool(self._updated_watchpoints or self._deleted_watchpoints) | |||
| enable_recheck = self.outdated | |||
| if backend == 'GPU' and enable_recheck: | |||
| # on GPU, disable to recheck if there are new watched node of which the tensor | |||
| # has not been stored on MindSpore | |||
| @@ -278,7 +280,7 @@ class WatchpointHandler(StreamHandlerBase): | |||
| self.validate_watchpoint_id(watch_point_id) | |||
| watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) | |||
| self.put(watchpoint) | |||
| self.outdated = True | |||
| return new_id | |||
| def update_watchpoint(self, watch_point_id, watch_nodes, watched=False): | |||
| @@ -300,6 +302,7 @@ class WatchpointHandler(StreamHandlerBase): | |||
| watchpoint.remove_nodes(watch_nodes) | |||
| self._remove_watch_node_from_cache(watch_nodes) | |||
| self._updated_watchpoints[watch_point_id] = watchpoint | |||
| self.outdated = True | |||
| log.debug("Update watchpoint %d in cache.", watch_point_id) | |||
| def delete_watchpoint(self, watch_point_id=None): | |||
| @@ -317,6 +320,7 @@ class WatchpointHandler(StreamHandlerBase): | |||
| watch_point_ids = [watch_point_id] | |||
| for single_id in watch_point_ids: | |||
| self._delete_single_watchpoint(single_id) | |||
| self.outdated = True | |||
| def _delete_single_watchpoint(self, watch_point_id): | |||
| """ | |||
| @@ -91,7 +91,6 @@ class TrainingControlOperator: | |||
| """ | |||
| metadata_stream = self._metadata_stream | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| self._cache_store.put_data(metadata_stream.get()) | |||
| log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state) | |||
| raise DebuggerContinueError( | |||
| "MindSpore is not ready to run or is running currently." | |||
| @@ -214,7 +213,6 @@ class TrainingControlOperator: | |||
| """ | |||
| metadata_stream = self._metadata_stream | |||
| if metadata_stream.state != ServerStatus.RUNNING.value: | |||
| self._cache_store.put_data(metadata_stream.get()) | |||
| log.error("The MindSpore is not running.") | |||
| raise DebuggerPauseError("The MindSpore is not running.") | |||
| metadata_stream.state = 'waiting' | |||
| @@ -0,0 +1,216 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """This module is aimed to deal with watchpoint commands.""" | |||
| from mindinsight.conditionmgr.common.utils import NodeBasicInfo | |||
| from mindinsight.conditionmgr.condition import ConditionIdEnum | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||
| DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ | |||
| DebuggerDeleteWatchPointError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import ServerStatus, \ | |||
| Streams | |||
| class WatchpointOperator: | |||
| """Watchpoint Operator.""" | |||
| def __init__(self, cache_store, condition_mgr): | |||
| self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT) | |||
| self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH) | |||
| self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA) | |||
| self._condition_mgr = condition_mgr | |||
| def create_watchpoint(self, params): | |||
| """ | |||
| Create watchpoint. | |||
| Args: | |||
| - watch_condition (dict): The watch condition. The format is like: | |||
| { | |||
| "id": "tensor_too_large", | |||
| "params": [ | |||
| { | |||
| "name": "abs_mean_gt", | |||
| "disable": false, | |||
| "value": 1.1 | |||
| } | |||
| ] | |||
| } | |||
| - id (str): Id of condition. | |||
| - params (list[dict]): The list of param for this condition. | |||
| - watch_nodes (list[str]): The list of node names. | |||
| - watch_point_id (int): The id of watchpoint. | |||
| - search_pattern (dict): The search pattern. | |||
| - graph_name (str): The relative graph_name of the watched node. | |||
| Returns: | |||
| dict, the id of new watchpoint and metadata info. | |||
| """ | |||
| watch_condition = params.get('watch_condition') | |||
| log.info("Received create watchpoint request. WatchCondition: %s", watch_condition) | |||
| metadata_stream = self._metadata_stream | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| log.error("Failed to create watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerCreateWatchPointError( | |||
| "Failed to create watchpoint as the MindSpore is not in waiting state.") | |||
| self._validate_watch_condition(watch_condition) | |||
| watch_nodes = self._get_watch_node_with_basic_info( | |||
| node_names=params.get('watch_nodes'), | |||
| search_pattern=params.get('search_pattern'), | |||
| graph_name=params.get('graph_name')) | |||
| watchpoint_stream = self._watchpoint_stream | |||
| watch_point_id = watchpoint_stream.create_watchpoint( | |||
| self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id')) | |||
| log.info("Create watchpoint %d", watch_point_id) | |||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend) | |||
| res = metadata_stream.get(['state', 'enable_recheck']) | |||
| res['id'] = watch_point_id | |||
| return res | |||
| def _validate_watch_condition(self, watch_condition): | |||
| """Validate watch condition.""" | |||
| metadata_stream = self._metadata_stream | |||
| if metadata_stream.backend == 'GPU' and watch_condition.get('id') in ( | |||
| ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value, ConditionIdEnum.OPERATOR_OVERFLOW.value): | |||
| log.error("GPU doesn't support overflow watch condition.") | |||
| raise DebuggerParamValueError("GPU doesn't support overflow watch condition.") | |||
| if metadata_stream.backend == 'Ascend' and watch_condition.get('id') == ConditionIdEnum.NAN.value: | |||
| log.error("Ascend doesn't support nan watch condition.") | |||
| raise DebuggerParamValueError("Ascend doesn't support nan watch condition.") | |||
| def update_watchpoint(self, params): | |||
| """ | |||
| Update watchpoint. | |||
| Args: | |||
| params (dict): Params for update watchpoint. | |||
| - watch_point_id (int): The id of watchpoint. | |||
| - watch_nodes (list[str]): The list of node names. | |||
| - mode (int): The update operator on nodes. 0 for remove nodes from watch nodes. | |||
| 1 for add nodes to watch nodes. | |||
| - search_pattern (dict): The search pattern. | |||
| - graph_name (str): The relative graph_name of the watched node. | |||
| Returns: | |||
| dict, the metadata info. | |||
| """ | |||
| metadata_stream = self._metadata_stream | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| log.error("Failed to update watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerUpdateWatchPointError( | |||
| "Failed to update watchpoint as the MindSpore is not in waiting state." | |||
| ) | |||
| # validate parameter | |||
| watchpoint_stream = self._watchpoint_stream | |||
| watch_point_id = params.get('watch_point_id') | |||
| watch_nodes = params.get('watch_nodes') | |||
| if not watch_nodes or not watch_point_id: | |||
| log.error("Invalid parameter for update watchpoint.") | |||
| raise DebuggerParamValueError("Invalid parameter for update watchpoint.") | |||
| watchpoint_stream.validate_watchpoint_id(watch_point_id) | |||
| # get node basic info for watch nodes | |||
| watch_nodes = self._get_watch_node_with_basic_info( | |||
| node_names=params.get('watch_nodes'), | |||
| search_pattern=params.get('search_pattern'), | |||
| graph_name=params.get('graph_name')) | |||
| watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode')) | |||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend) | |||
| log.info("Update watchpoint with id: %d", watch_point_id) | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None): | |||
| """ | |||
| Get watch node with basic info. | |||
| Args: | |||
| node_names (list[str]): A list of node names. | |||
| search_pattern (dict): Get watch node with search pattern. Default: None | |||
| graph_name (str): The relative graph_name of the watched node. Default: None. | |||
| Returns: | |||
| list[NodeBasicInfo], a list of node basic infos. | |||
| """ | |||
| graph_name = self._graph_stream.validate_graph_name(graph_name) | |||
| if search_pattern is not None: | |||
| watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name) | |||
| else: | |||
| watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name) | |||
| return watch_nodes | |||
| def _get_watch_nodes_by_search(self, watch_nodes, search_pattern, graph_name): | |||
| """Get watched leaf nodes by search name.""" | |||
| watched_leaf_nodes = [] | |||
| graph_stream = self._graph_stream | |||
| new_pattern = {'graph_name': graph_name}.update(search_pattern) | |||
| for search_name in watch_nodes: | |||
| search_nodes = graph_stream.get_searched_node_list(new_pattern) | |||
| search_node_names = [ | |||
| NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) | |||
| for node in search_nodes | |||
| if node.name.startswith(search_name)] | |||
| watched_leaf_nodes.extend(search_node_names) | |||
| log.debug("Update nodes: %s", watched_leaf_nodes) | |||
| return watched_leaf_nodes | |||
| def delete_watchpoint(self, watch_point_id=None): | |||
| """ | |||
| Delete watchpoint. | |||
| Args: | |||
| watch_point_id (Union[None, int]): The id of watchpoint. | |||
| If None, delete all watchpoints. Default: None. | |||
| Returns: | |||
| dict, the metadata info. | |||
| """ | |||
| metadata_stream = self._metadata_stream | |||
| if metadata_stream.state != ServerStatus.WAITING.value: | |||
| log.error("Failed to delete watchpoint as the MindSpore is not in waiting state.") | |||
| raise DebuggerDeleteWatchPointError( | |||
| "Failed to delete watchpoint as the MindSpore is not in waiting state." | |||
| ) | |||
| watchpoint_stream = self._watchpoint_stream | |||
| watchpoint_stream.delete_watchpoint(watch_point_id) | |||
| metadata_stream.enable_recheck = watchpoint_stream.is_recheckable() | |||
| log.info("Delete watchpoint with id: %s", watch_point_id) | |||
| return metadata_stream.get(['state', 'enable_recheck']) | |||
| def _get_node_basic_infos(self, node_names, graph_name=None): | |||
| """ | |||
| Get node info according to node names. | |||
| Args: | |||
| node_names (list[str]): A list of node names. | |||
| graph_name (str): The relative graph_name of the watched node. Default: None. | |||
| Returns: | |||
| list[NodeBasicInfo], a list of basic node infos. | |||
| """ | |||
| if not node_names: | |||
| return [] | |||
| graph_stream = self._graph_stream | |||
| node_infos = [] | |||
| for node_name in node_names: | |||
| node_info = graph_stream.get_node_basic_info(node_name, graph_name) | |||
| node_infos.append(node_info) | |||
| return node_infos | |||
| @@ -190,7 +190,7 @@ class TestDebuggerServer: | |||
| def test_create_watchpoint_with_wrong_state(self): | |||
| """Test create watchpoint with wrong state.""" | |||
| with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'): | |||
| self._server.create_watchpoint(watch_condition={'condition': 'INF'}) | |||
| self._server.create_watchpoint({'watch_condition': {'condition': 'INF'}}) | |||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||
| @mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=[MagicMock()]) | |||
| @@ -199,7 +199,8 @@ class TestDebuggerServer: | |||
| def test_create_watchpoint(self, *args): | |||
| """Test create watchpoint.""" | |||
| args[0].return_value = 1 | |||
| res = self._server.create_watchpoint({'condition': 'INF'}, ['watch_node_name']) | |||
| res = self._server.create_watchpoint({'watch_condition': {'condition': 'INF'}, | |||
| 'watch_nodes': ['watch_node_name']}) | |||
| assert res == {'id': 1, 'metadata': {'enable_recheck': False, 'state': 'waiting'}} | |||
| @mock.patch.object(MetadataHandler, 'state', 'waiting') | |||
| @@ -211,8 +212,11 @@ class TestDebuggerServer: | |||
| """Test update watchpoint.""" | |||
| args[2].return_value = [MagicMock(name='search_name/op_name')] | |||
| res = self._server.update_watchpoint( | |||
| watch_point_id=1, watch_nodes=['search_name'], | |||
| mode=1, search_pattern={'name': 'search_name'}, graph_name='kernel_graph_0') | |||
| {'watch_point_id': 1, | |||
| 'watch_nodes': ['search_name'], | |||
| 'mode': 1, | |||
| 'search_pattern': {'name': 'search_name'}, | |||
| 'graph_name': 'kernel_graph_0'}) | |||
| assert res == {'metadata': {'enable_recheck': False, 'state': 'waiting'}} | |||
| def test_delete_watchpoint_with_wrong_state(self): | |||