Browse Source

refactor watchpoint

tags/v1.1.0
yelihua 5 years ago
parent
commit
bfb57114b0
6 changed files with 295 additions and 171 deletions
  1. +5
    -17
      mindinsight/backend/debugger/debugger_api.py
  2. +60
    -146
      mindinsight/debugger/debugger_server.py
  3. +6
    -2
      mindinsight/debugger/stream_handler/watchpoint_handler.py
  4. +0
    -2
      mindinsight/debugger/stream_operator/training_control_operator.py
  5. +216
    -0
      mindinsight/debugger/stream_operator/watchpoint_operator.py
  6. +8
    -4
      tests/ut/debugger/test_debugger_server.py

+ 5
- 17
mindinsight/backend/debugger/debugger_api.py View File

@@ -213,15 +213,9 @@ def create_watchpoint():
Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/create_watchpoint
"""
body = _read_post_request(request)

condition = body.get('condition')
graph_name = body.get('graph_name')
watch_nodes = body.get('watch_nodes')
watch_point_id = body.get('watch_point_id')
search_pattern = body.get('search_pattern')
reply = _wrap_reply(BACKEND_SERVER.create_watchpoint,
condition, watch_nodes, watch_point_id, search_pattern, graph_name)
params = _read_post_request(request)
params['watch_condition'] = params.pop('condition', None)
reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, params)
return reply


@@ -239,14 +233,8 @@ def update_watchpoint():
Examples:
>>> POST http://xxxx/v1/mindinsight/debugger/update_watchpoint
"""
body = _read_post_request(request)

watch_point_id = body.get('watch_point_id')
watch_nodes = body.get('watch_nodes')
graph_name = body.get('graph_name')
mode = body.get('mode')
pattern = body.get('search_pattern')
reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, watch_point_id, watch_nodes, mode, pattern, graph_name)
params = _read_post_request(request)
reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, params)
return reply




+ 60
- 146
mindinsight/debugger/debugger_server.py View File

@@ -15,21 +15,20 @@
"""Implement the debugger server."""
import signal
from concurrent import futures
from functools import wraps
from threading import Thread

import grpc

from mindinsight.conditionmgr.common.utils import NodeBasicInfo
from mindinsight.conditionmgr.condition import ConditionContext, ConditionIdEnum
from mindinsight.conditionmgr.condition import ConditionContext
from mindinsight.conditionmgr.conditionmgr import ConditionMgr
from mindinsight.conditionmgr.recommender import recommend_watchpoints
from mindinsight.conf import settings
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.utils.tools import to_float
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
DebuggerDeleteWatchPointError, DebuggerCompareTensorError, DebuggerTensorGraphError, \
DebuggerTensorHitError
DebuggerParamTypeError, DebuggerCompareTensorError, DebuggerTensorGraphError, \
DebuggerTensorHitError, MindInsightException
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import ServerStatus, \
create_view_event_from_tensor_basic_info, Streams
@@ -38,9 +37,26 @@ from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.stream_operator.tensor_detail_info import TensorDetailInfo
from mindinsight.debugger.stream_operator.training_control_operator import TrainingControlOperator
from mindinsight.debugger.stream_operator.watchpoint_operator import WatchpointOperator
from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR


def try_except(func):
"""Send latest metadata when catch exception."""

@wraps(func)
def send_latest_metadata(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except MindInsightException as err:
metadata = self.cache_store.get_stream_handler(Streams.METADATA).get()
self.cache_store.put_data(metadata)
log.info("Put latest metadata into data-queue.")
raise err

return send_latest_metadata


class DebuggerServer:
"""The server manager of debugger."""

@@ -488,131 +504,57 @@ class DebuggerServer:

return reply

def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None, search_pattern=None,
graph_name=None):
def create_watchpoint(self, params):
"""
Create watchpoint.

Args:
watch_condition (dict): The watch condition. The format is like:
{
"id": "tensor_too_large",
"params": [
{
"name": "abs_mean_gt",
"disable": false,
"value": 1.1
}
]
}

- id (str): Id of condition.
- params (list[dict]): The list of param for this condition.
watch_nodes (list[str]): The list of node names.
watch_point_id (int): The id of watchpoint.
search_pattern (dict): The search pattern. Default: None.
graph_name (str): The relative graph_name of the watched node. Default: None.
params (dict): Params for create watchpoint.

- watch_condition (dict): The watch condition. The format is like:
{
"id": "tensor_too_large",
"params": [
{
"name": "abs_mean_gt",
"disable": false,
"value": 1.1
}
]
}

- id (str): Id of condition.
- params (list[dict]): The list of param for this condition.
- watch_nodes (list[str]): The list of node names.
- watch_point_id (int): The id of watchpoint.
- search_pattern (dict): The search pattern.
- graph_name (str): The relative graph_name of the watched node.

Returns:
dict, the id of new watchpoint and metadata info.
"""
log.info("Received create watchpoint request. WatchCondition: %s", watch_condition)
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if metadata_stream.state != ServerStatus.WAITING.value:
log.error("Failed to create watchpoint as the MindSpore is not in waiting state.")
raise DebuggerCreateWatchPointError(
"Failed to create watchpoint as the MindSpore is not in waiting state.")
if metadata_stream.backend == 'GPU' and watch_condition.get('id') in (
ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value, ConditionIdEnum.OPERATOR_OVERFLOW.value):
log.error("GPU doesn't support overflow watch condition.")
raise DebuggerParamValueError("GPU doesn't support overflow watch condition.")

if metadata_stream.backend == 'Ascend' and watch_condition.get('id') == ConditionIdEnum.NAN.value:
log.error("Ascend doesn't support nan watch condition.")
raise DebuggerParamValueError("Ascend doesn't support nan watch condition.")

watch_nodes = self._get_watch_node_with_basic_info(
node_names=watch_nodes, search_pattern=search_pattern, graph_name=graph_name)
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watch_point_id = watchpoint_stream.create_watchpoint(
self.condition_mgr, watch_condition, watch_nodes, watch_point_id)
log.info("Create watchpoint %d", watch_point_id)
watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
return watchpoint_opt.create_watchpoint(params)

metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend)
res = metadata_stream.get(['state', 'enable_recheck'])
res['id'] = watch_point_id
return res

def update_watchpoint(self, watch_point_id, watch_nodes, mode, search_pattern=None, graph_name=None):
def update_watchpoint(self, params):
"""
Update watchpoint.

Args:
watch_point_id (int): The id of watchpoint.
watch_nodes (list[str]): The list of node names.
mode (int): The update operator on nodes. 0 for remove nodes from watch nodes.
1 for add nodes to watch nodes.
search_pattern (dict): The search pattern. Default: None.
graph_name (str): The relative graph_name of the watched node. Default: None.
params (dict): Params for update watchpoint.

Returns:
dict, the metadata info.
"""
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if metadata_stream.state != ServerStatus.WAITING.value:
log.error("Failed to update watchpoint as the MindSpore is not in waiting state.")
raise DebuggerUpdateWatchPointError(
"Failed to update watchpoint as the MindSpore is not in waiting state."
)
# validate parameter
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.validate_watchpoint_id(watch_point_id)
if not watch_nodes or not watch_point_id:
log.error("Invalid parameter for update watchpoint.")
raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
# get node basic info for watch nodes
watch_nodes = self._get_watch_node_with_basic_info(watch_nodes, search_pattern, graph_name)
watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, mode)
metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend)
log.info("Update watchpoint with id: %d", watch_point_id)
return metadata_stream.get(['state', 'enable_recheck'])

def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None):
"""
Get watch node with basic info.

Args:
node_names (list[str]): A list of node names.
search_pattern (dict): Get watch node with search pattern. Default: None
graph_name (str): The relative graph_name of the watched node. Default: None.
- watch_point_id (int): The id of watchpoint.
- watch_nodes (list[str]): The list of node names.
- mode (int): The update operator on nodes. 0 for remove nodes from watch nodes.
1 for add nodes to watch nodes.
- search_pattern (dict): The search pattern.
- graph_name (str): The relative graph_name of the watched node.

Returns:
list[NodeBasicInfo], a list of node basic infos.
dict, the metadata info.
"""
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
graph_name = graph_stream.validate_graph_name(graph_name)
if search_pattern is not None:
watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name)
else:
watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name)
return watch_nodes

def _get_watch_nodes_by_search(self, watch_nodes, search_pattern, graph_name):
"""Get watched leaf nodes by search name."""
watched_leaf_nodes = []
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
new_pattern = {'graph_name': graph_name}.update(search_pattern)
for search_name in watch_nodes:
search_nodes = graph_stream.get_searched_node_list(new_pattern)
search_node_names = [
NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
for node in search_nodes
if node.name.startswith(search_name)]
watched_leaf_nodes.extend(search_node_names)

log.debug("Update nodes: %s", watched_leaf_nodes)

return watched_leaf_nodes
watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
return watchpoint_opt.update_watchpoint(params)

def delete_watchpoint(self, watch_point_id=None):
"""
@@ -625,39 +567,10 @@ class DebuggerServer:
Returns:
dict, the metadata info.
"""
metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
if metadata_stream.state != ServerStatus.WAITING.value:
log.error("Failed to delete watchpoint as the MindSpore is not in waiting state.")
raise DebuggerDeleteWatchPointError(
"Failed to delete watchpoint as the MindSpore is not in waiting state."
)
watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
watchpoint_stream.delete_watchpoint(watch_point_id)
metadata_stream.enable_recheck = watchpoint_stream.is_recheckable()
log.info("Delete watchpoint with id: %s", watch_point_id)
return metadata_stream.get(['state', 'enable_recheck'])

def _get_node_basic_infos(self, node_names, graph_name=None):
"""
Get node info according to node names.

Args:
node_names (list[str]): A list of node names.
graph_name (str): The relative graph_name of the watched node. Default: None.

Returns:
list[NodeBasicInfo], a list of basic node infos.
"""
if not node_names:
return []
graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
node_infos = []
for node_name in node_names:
node_info = graph_stream.get_node_basic_info(node_name, graph_name)
node_infos.append(node_info)

return node_infos
watchpoint_opt = WatchpointOperator(self.cache_store, self.condition_mgr)
return watchpoint_opt.delete_watchpoint(watch_point_id=watch_point_id)

@try_except
def control(self, params=None):
"""
Control the training process.
@@ -678,7 +591,7 @@ class DebuggerServer:
dict, the response.
"""
log.info("Receive control request: %s.", params)
mode = params.pop('mode', None)
mode = params.pop('mode', None) if params else None
training_controller = TrainingControlOperator(self.cache_store)
training_controller.validate_mode(mode)
return training_controller.control(mode, params)
@@ -717,6 +630,7 @@ class DebuggerServer:

return reply

@try_except
def recheck(self):
"""
Recheck all watchpoints.


+ 6
- 2
mindinsight/debugger/stream_handler/watchpoint_handler.py View File

@@ -44,6 +44,8 @@ class WatchpointHandler(StreamHandlerBase):
self._temp_cached_node_full_names = set()
self._latest_id = 0
self._cache_set_cmd = {}
# whether the watchpoint list has been changed since last step
self.outdated = False

def put(self, value):
"""
@@ -176,7 +178,7 @@ class WatchpointHandler(StreamHandlerBase):
Returns:
bool, if enable to recheck.
"""
enable_recheck = bool(self._updated_watchpoints or self._deleted_watchpoints)
enable_recheck = self.outdated
if backend == 'GPU' and enable_recheck:
# on GPU, disable to recheck if there are new watched node of which the tensor
# has not been stored on MindSpore
@@ -278,7 +280,7 @@ class WatchpointHandler(StreamHandlerBase):
self.validate_watchpoint_id(watch_point_id)
watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id))
self.put(watchpoint)
self.outdated = True
return new_id

def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
@@ -300,6 +302,7 @@ class WatchpointHandler(StreamHandlerBase):
watchpoint.remove_nodes(watch_nodes)
self._remove_watch_node_from_cache(watch_nodes)
self._updated_watchpoints[watch_point_id] = watchpoint
self.outdated = True
log.debug("Update watchpoint %d in cache.", watch_point_id)

def delete_watchpoint(self, watch_point_id=None):
@@ -317,6 +320,7 @@ class WatchpointHandler(StreamHandlerBase):
watch_point_ids = [watch_point_id]
for single_id in watch_point_ids:
self._delete_single_watchpoint(single_id)
self.outdated = True

def _delete_single_watchpoint(self, watch_point_id):
"""


+ 0
- 2
mindinsight/debugger/stream_operator/training_control_operator.py View File

@@ -91,7 +91,6 @@ class TrainingControlOperator:
"""
metadata_stream = self._metadata_stream
if metadata_stream.state != ServerStatus.WAITING.value:
self._cache_store.put_data(metadata_stream.get())
log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
raise DebuggerContinueError(
"MindSpore is not ready to run or is running currently."
@@ -214,7 +213,6 @@ class TrainingControlOperator:
"""
metadata_stream = self._metadata_stream
if metadata_stream.state != ServerStatus.RUNNING.value:
self._cache_store.put_data(metadata_stream.get())
log.error("The MindSpore is not running.")
raise DebuggerPauseError("The MindSpore is not running.")
metadata_stream.state = 'waiting'


+ 216
- 0
mindinsight/debugger/stream_operator/watchpoint_operator.py View File

@@ -0,0 +1,216 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""This module is aimed to deal with watchpoint commands."""

from mindinsight.conditionmgr.common.utils import NodeBasicInfo
from mindinsight.conditionmgr.condition import ConditionIdEnum
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \
DebuggerDeleteWatchPointError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import ServerStatus, \
Streams


class WatchpointOperator:
"""Watchpoint Operator."""

def __init__(self, cache_store, condition_mgr):
self._watchpoint_stream = cache_store.get_stream_handler(Streams.WATCHPOINT)
self._graph_stream = cache_store.get_stream_handler(Streams.GRAPH)
self._metadata_stream = cache_store.get_stream_handler(Streams.METADATA)
self._condition_mgr = condition_mgr

def create_watchpoint(self, params):
"""
Create watchpoint.

Args:
- watch_condition (dict): The watch condition. The format is like:
{
"id": "tensor_too_large",
"params": [
{
"name": "abs_mean_gt",
"disable": false,
"value": 1.1
}
]
}

- id (str): Id of condition.
- params (list[dict]): The list of param for this condition.
- watch_nodes (list[str]): The list of node names.
- watch_point_id (int): The id of watchpoint.
- search_pattern (dict): The search pattern.
- graph_name (str): The relative graph_name of the watched node.

Returns:
dict, the id of new watchpoint and metadata info.
"""
watch_condition = params.get('watch_condition')
log.info("Received create watchpoint request. WatchCondition: %s", watch_condition)
metadata_stream = self._metadata_stream
if metadata_stream.state != ServerStatus.WAITING.value:
log.error("Failed to create watchpoint as the MindSpore is not in waiting state.")
raise DebuggerCreateWatchPointError(
"Failed to create watchpoint as the MindSpore is not in waiting state.")
self._validate_watch_condition(watch_condition)

watch_nodes = self._get_watch_node_with_basic_info(
node_names=params.get('watch_nodes'),
search_pattern=params.get('search_pattern'),
graph_name=params.get('graph_name'))
watchpoint_stream = self._watchpoint_stream
watch_point_id = watchpoint_stream.create_watchpoint(
self._condition_mgr, watch_condition, watch_nodes, params.get('watch_point_id'))
log.info("Create watchpoint %d", watch_point_id)

metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend)
res = metadata_stream.get(['state', 'enable_recheck'])
res['id'] = watch_point_id
return res

def _validate_watch_condition(self, watch_condition):
"""Validate watch condition."""
metadata_stream = self._metadata_stream
if metadata_stream.backend == 'GPU' and watch_condition.get('id') in (
ConditionIdEnum.OVERFLOW_ASCEND_CHIP.value, ConditionIdEnum.OPERATOR_OVERFLOW.value):
log.error("GPU doesn't support overflow watch condition.")
raise DebuggerParamValueError("GPU doesn't support overflow watch condition.")
if metadata_stream.backend == 'Ascend' and watch_condition.get('id') == ConditionIdEnum.NAN.value:
log.error("Ascend doesn't support nan watch condition.")
raise DebuggerParamValueError("Ascend doesn't support nan watch condition.")

def update_watchpoint(self, params):
"""
Update watchpoint.

Args:
params (dict): Params for update watchpoint.

- watch_point_id (int): The id of watchpoint.
- watch_nodes (list[str]): The list of node names.
- mode (int): The update operator on nodes. 0 for remove nodes from watch nodes.
1 for add nodes to watch nodes.
- search_pattern (dict): The search pattern.
- graph_name (str): The relative graph_name of the watched node.

Returns:
dict, the metadata info.
"""
metadata_stream = self._metadata_stream
if metadata_stream.state != ServerStatus.WAITING.value:
log.error("Failed to update watchpoint as the MindSpore is not in waiting state.")
raise DebuggerUpdateWatchPointError(
"Failed to update watchpoint as the MindSpore is not in waiting state."
)
# validate parameter
watchpoint_stream = self._watchpoint_stream
watch_point_id = params.get('watch_point_id')
watch_nodes = params.get('watch_nodes')
if not watch_nodes or not watch_point_id:
log.error("Invalid parameter for update watchpoint.")
raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
watchpoint_stream.validate_watchpoint_id(watch_point_id)
# get node basic info for watch nodes
watch_nodes = self._get_watch_node_with_basic_info(
node_names=params.get('watch_nodes'),
search_pattern=params.get('search_pattern'),
graph_name=params.get('graph_name'))
watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, params.get('mode'))
metadata_stream.enable_recheck = watchpoint_stream.is_recheckable(metadata_stream.backend)
log.info("Update watchpoint with id: %d", watch_point_id)
return metadata_stream.get(['state', 'enable_recheck'])

def _get_watch_node_with_basic_info(self, node_names, search_pattern=None, graph_name=None):
"""
Get watch node with basic info.

Args:
node_names (list[str]): A list of node names.
search_pattern (dict): Get watch node with search pattern. Default: None
graph_name (str): The relative graph_name of the watched node. Default: None.

Returns:
list[NodeBasicInfo], a list of node basic infos.
"""
graph_name = self._graph_stream.validate_graph_name(graph_name)
if search_pattern is not None:
watch_nodes = self._get_watch_nodes_by_search(node_names, search_pattern, graph_name)
else:
watch_nodes = self._get_node_basic_infos(node_names, graph_name=graph_name)
return watch_nodes

def _get_watch_nodes_by_search(self, watch_nodes, search_pattern, graph_name):
"""Get watched leaf nodes by search name."""
watched_leaf_nodes = []
graph_stream = self._graph_stream
new_pattern = {'graph_name': graph_name}.update(search_pattern)
for search_name in watch_nodes:
search_nodes = graph_stream.get_searched_node_list(new_pattern)
search_node_names = [
NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
for node in search_nodes
if node.name.startswith(search_name)]
watched_leaf_nodes.extend(search_node_names)

log.debug("Update nodes: %s", watched_leaf_nodes)

return watched_leaf_nodes

def delete_watchpoint(self, watch_point_id=None):
"""
Delete watchpoint.

Args:
watch_point_id (Union[None, int]): The id of watchpoint.
If None, delete all watchpoints. Default: None.

Returns:
dict, the metadata info.
"""
metadata_stream = self._metadata_stream
if metadata_stream.state != ServerStatus.WAITING.value:
log.error("Failed to delete watchpoint as the MindSpore is not in waiting state.")
raise DebuggerDeleteWatchPointError(
"Failed to delete watchpoint as the MindSpore is not in waiting state."
)
watchpoint_stream = self._watchpoint_stream
watchpoint_stream.delete_watchpoint(watch_point_id)
metadata_stream.enable_recheck = watchpoint_stream.is_recheckable()
log.info("Delete watchpoint with id: %s", watch_point_id)
return metadata_stream.get(['state', 'enable_recheck'])

def _get_node_basic_infos(self, node_names, graph_name=None):
"""
Get node info according to node names.

Args:
node_names (list[str]): A list of node names.
graph_name (str): The relative graph_name of the watched node. Default: None.

Returns:
list[NodeBasicInfo], a list of basic node infos.
"""
if not node_names:
return []
graph_stream = self._graph_stream
node_infos = []
for node_name in node_names:
node_info = graph_stream.get_node_basic_info(node_name, graph_name)
node_infos.append(node_info)

return node_infos

+ 8
- 4
tests/ut/debugger/test_debugger_server.py View File

@@ -190,7 +190,7 @@ class TestDebuggerServer:
def test_create_watchpoint_with_wrong_state(self):
"""Test create watchpoint with wrong state."""
with pytest.raises(DebuggerCreateWatchPointError, match='Failed to create watchpoint'):
self._server.create_watchpoint(watch_condition={'condition': 'INF'})
self._server.create_watchpoint({'watch_condition': {'condition': 'INF'}})

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@mock.patch.object(GraphHandler, 'get_node_basic_info', return_value=[MagicMock()])
@@ -199,7 +199,8 @@ class TestDebuggerServer:
def test_create_watchpoint(self, *args):
"""Test create watchpoint."""
args[0].return_value = 1
res = self._server.create_watchpoint({'condition': 'INF'}, ['watch_node_name'])
res = self._server.create_watchpoint({'watch_condition': {'condition': 'INF'},
'watch_nodes': ['watch_node_name']})
assert res == {'id': 1, 'metadata': {'enable_recheck': False, 'state': 'waiting'}}

@mock.patch.object(MetadataHandler, 'state', 'waiting')
@@ -211,8 +212,11 @@ class TestDebuggerServer:
"""Test update watchpoint."""
args[2].return_value = [MagicMock(name='search_name/op_name')]
res = self._server.update_watchpoint(
watch_point_id=1, watch_nodes=['search_name'],
mode=1, search_pattern={'name': 'search_name'}, graph_name='kernel_graph_0')
{'watch_point_id': 1,
'watch_nodes': ['search_name'],
'mode': 1,
'search_pattern': {'name': 'search_name'},
'graph_name': 'kernel_graph_0'})
assert res == {'metadata': {'enable_recheck': False, 'state': 'waiting'}}

def test_delete_watchpoint_with_wrong_state(self):


Loading…
Cancel
Save