| @@ -0,0 +1,26 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Module init file.""" | |||||
| from mindinsight.backend.debugger.debugger_api import init_module as init_query_module | |||||
def init_module(app):
    """
    Init module entry.

    Delegates to ``init_module`` of ``mindinsight.backend.debugger.debugger_api``,
    which is imported in this file under the name ``init_query_module``.

    Args:
        app (Flask): A Flask instance.
    """
    init_query_module(app)
| @@ -0,0 +1,300 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Debugger restful api.""" | |||||
| import json | |||||
| from flask import Blueprint, jsonify, request | |||||
| from mindinsight.conf import settings | |||||
| from mindinsight.debugger.debugger_server import DebuggerServer | |||||
| from mindinsight.utils.exceptions import ParamValueError | |||||
# Blueprint that carries every debugger REST route declared in this module.
# Its URL prefix is settings.URL_PATH_PREFIX followed by settings.API_PREFIX.
BLUEPRINT = Blueprint("debugger", __name__,
                      url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX)
def _initialize_debugger_server():
    """
    Build the debugger server when the settings ask for one.

    Returns:
        Union[DebuggerServer, None], a server constructed with `settings.DEBUGGER_PORT`
        when both that port and `settings.ENABLE_DEBUGGER` are truthy, otherwise None.
    """
    # Either setting may be absent from the settings object; fall back to "off".
    debugger_port = getattr(settings, 'DEBUGGER_PORT', None)
    debugger_enabled = getattr(settings, 'ENABLE_DEBUGGER', False)
    if debugger_port and debugger_enabled:
        return DebuggerServer(debugger_port)
    return None
def _read_post_request(post_request):
    """
    Extract the body of post request.

    Args:
        post_request (object): The post request. Only `post_request.stream.read()` is used.

    Returns:
        dict, the deserialized body of request. An empty body yields an empty dict.

    Raises:
        ParamValueError: If the body is not valid JSON, or is valid JSON that is
            not an object (the handlers call `.get` on the result).
    """
    raw_body = post_request.stream.read()
    try:
        body = json.loads(raw_body if raw_body else "{}")
    except (TypeError, ValueError, RecursionError) as err:
        # ValueError covers JSONDecodeError and undecodable bytes; RecursionError
        # covers pathologically nested documents.
        raise ParamValueError("Json data parse failed.") from err
    if not isinstance(body, dict):
        # A bare list/number/string is valid JSON but would crash every caller
        # on `body.get(...)`; reject it as a parameter error instead.
        raise ParamValueError("Json data should be a dict.")
    return body
def _wrap_reply(func, *args, **kwargs):
    """
    Call `func` with the given arguments and serialize what it returns.

    Args:
        func (Callable): The callable to invoke.
        args (tuple): Positional arguments forwarded to `func`.
        kwargs (dict): Keyword arguments forwarded to `func`.

    Returns:
        Response, the result of `func` passed through `jsonify`.
    """
    return jsonify(func(*args, **kwargs))
@BLUEPRINT.route("/debugger/poll_data", methods=["GET"])
def poll_data():
    """
    Wait for data to be updated on UI.

    Forwards the `pos` query argument to `BACKEND_SERVER.poll_data` and returns
    the serialized result.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/poll_data?pos=xx
    """
    position = request.args.get('pos')
    return _wrap_reply(BACKEND_SERVER.poll_data, position)
@BLUEPRINT.route("/debugger/search", methods=["GET"])
def search():
    """
    Search nodes in specified watchpoint.

    Reads the `name` and `watch_point_id` query arguments; `watch_point_id`
    defaults to 0 when it is absent.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Raises:
        ParamValueError: If `watch_point_id` is present but is not an integer.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/search?name=xx&watch_point_id=xx
    """
    name = request.args.get('name')
    raw_watch_point_id = request.args.get('watch_point_id', 0)
    try:
        watch_point_id = int(raw_watch_point_id)
    except ValueError as err:
        # Query values are user-supplied strings; report a bad one as a
        # parameter error instead of letting ValueError escape the handler.
        raise ParamValueError("watch_point_id should be an integer.") from err
    return _wrap_reply(BACKEND_SERVER.search, name, watch_point_id)
@BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"])
def retrieve_node_by_bfs():
    """
    Search node by bfs.

    Reads the `name` and `ascend` query arguments and forwards them to
    `BACKEND_SERVER.retrieve_node_by_bfs`.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true
    """
    node_name = request.args.get('name')
    # Only the exact lowercase literal 'true' enables ascend; anything else is False.
    is_ascend = request.args.get('ascend', 'false') == 'true'
    return _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, node_name, is_ascend)
@BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"])
def tensor_comparisons():
    """
    Get tensor comparisons.

    Reads `name`, `detail` (default 'data'), `shape` and `tolerance`
    (default '0') from the query string; all values are forwarded as strings.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/tensor-comparisons?
        name=node_name&detail=data&shape=[0, 0, :, :]&tolerance=0.5
    """
    query = request.args
    name = query.get('name')
    detail = query.get('detail', 'data')
    shape = query.get('shape')
    tolerance = query.get('tolerance', '0')
    # NOTE(review): forwarded positionally as (name, shape, detail, tolerance);
    # confirm this order against DebuggerServer.tensor_comparisons.
    return _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance)
@BLUEPRINT.route("/debugger/retrieve", methods=["POST"])
def retrieve():
    """
    Retrieve data according to mode and params.

    The JSON request body supplies the `mode` and `params` keys; a missing key
    is forwarded as None.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/retrieve
    """
    payload = _read_post_request(request)
    arguments = [payload.get(key) for key in ('mode', 'params')]
    return _wrap_reply(BACKEND_SERVER.retrieve, *arguments)
@BLUEPRINT.route("/debugger/retrieve_tensor_history", methods=["POST"])
def retrieve_tensor_history():
    """
    Retrieve tensor history according to name.

    The JSON request body supplies the `name` key; a missing key is forwarded
    as None.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/retrieve_tensor_history
    """
    payload = _read_post_request(request)
    return _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, payload.get('name'))
@BLUEPRINT.route("/debugger/tensors", methods=["GET"])
def retrieve_tensor_value():
    """
    Retrieve tensor value according to name and shape.

    Reads `name`, `detail` and `shape` from the query string; an absent
    argument is forwarded as None.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=node_name&detail=data&shape=[1,1,:,:]
    """
    arguments = [request.args.get(key) for key in ('name', 'detail', 'shape')]
    return _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, *arguments)
@BLUEPRINT.route("/debugger/create_watchpoint", methods=["POST"])
def create_watchpoint():
    """
    Create watchpoint.

    The JSON request body supplies `condition`, `watch_nodes` and
    `watch_point_id`; a missing key is forwarded as None.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/create_watchpoint
    """
    payload = _read_post_request(request)
    keys = ('condition', 'watch_nodes', 'watch_point_id')
    arguments = [payload.get(key) for key in keys]
    return _wrap_reply(BACKEND_SERVER.create_watchpoint, *arguments)
@BLUEPRINT.route("/debugger/update_watchpoint", methods=["POST"])
def update_watchpoint():
    """
    Update watchpoint.

    The JSON request body supplies `watch_point_id`, `watch_nodes`, `mode` and
    `name`; a missing key is forwarded as None.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/update_watchpoint
    """
    payload = _read_post_request(request)
    keys = ('watch_point_id', 'watch_nodes', 'mode', 'name')
    arguments = [payload.get(key) for key in keys]
    return _wrap_reply(BACKEND_SERVER.update_watchpoint, *arguments)
@BLUEPRINT.route("/debugger/delete_watchpoint", methods=["POST"])
def delete_watchpoint():
    """
    Delete watchpoint.

    The JSON request body supplies `watch_point_id`; a missing key is
    forwarded as None.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/delete_watchpoint
    """
    payload = _read_post_request(request)
    return _wrap_reply(BACKEND_SERVER.delete_watchpoint, payload.get('watch_point_id'))
@BLUEPRINT.route("/debugger/control", methods=["POST"])
def control():
    """
    Control request.

    The whole deserialized JSON body is handed to `BACKEND_SERVER.control`
    as a single argument.

    Returns:
        Response, the JSON reply built by `_wrap_reply`.

    Raises:
        ParamValueError: If the request body cannot be parsed as JSON.

    Examples:
        >>> POST http://xxxx/v1/mindinsight/debugger/control
    """
    control_params = _read_post_request(request)
    return _wrap_reply(BACKEND_SERVER.control, control_params)
# Created once when this module is imported. It is None when
# _initialize_debugger_server() finds the debugger disabled or no port configured.
BACKEND_SERVER = _initialize_debugger_server()
def init_module(app):
    """
    Init module entry.

    Registers the debugger blueprint on the application and starts the
    debugger server if one was created.

    Args:
        app (Flask): The application obj.
    """
    app.register_blueprint(BLUEPRINT)
    # NOTE(review): the routes are registered even when BACKEND_SERVER is None,
    # and every handler dereferences BACKEND_SERVER directly - confirm whether
    # the routes should be reachable while the debugger is disabled.
    if not BACKEND_SERVER:
        return
    BACKEND_SERVER.start()
| @@ -26,6 +26,12 @@ WORKSPACE = os.path.join(os.environ['HOME'], 'mindinsight') | |||||
| PORT = 8080 | PORT = 8080 | ||||
| URL_PATH_PREFIX = '' | URL_PATH_PREFIX = '' | ||||
####################################
# Debugger default settings.
####################################
# Port value handed to DebuggerServer; note that it is stored as a string.
DEBUGGER_PORT = '50051'
# The debugger server is only created when this flag is truthy and a port is set.
ENABLE_DEBUGGER = False
| #################################### | #################################### | ||||
| # Datavisual default settings. | # Datavisual default settings. | ||||
| #################################### | #################################### | ||||
| @@ -15,128 +15,14 @@ | |||||
| """Tensor data container.""" | """Tensor data container.""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.datavisual.common.log import logger | |||||
| from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket | from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket | ||||
| from mindinsight.datavisual.utils.utils import calc_histogram_bins | from mindinsight.datavisual.utils.utils import calc_histogram_bins | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from mindinsight.utils.tensor import TensorUtils | |||||
| F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max | |||||
| MAX_TENSOR_COUNT = 10000000 | MAX_TENSOR_COUNT = 10000000 | ||||
class Statistics:
    """
    Read-only holder for the summary statistics of one tensor.

    Args:
        max_value (float): max value of tensor data.
        min_value (float): min value of tensor data.
        avg_value (float): avg value of tensor data.
        count (int): total count of tensor data.
        nan_count (int): count of NAN.
        neg_inf_count (int): count of negative INF.
        pos_inf_count (int): count of positive INF.
    """

    def __init__(self, max_value=0, min_value=0, avg_value=0,
                 count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0):
        # Value statistics.
        self._max, self._min, self._avg = max_value, min_value, avg_value
        # Element counters.
        self._count = count
        self._nan_count = nan_count
        self._neg_inf_count, self._pos_inf_count = neg_inf_count, pos_inf_count

    @property
    def max(self):
        """Max value of the tensor."""
        return self._max

    @property
    def min(self):
        """Min value of the tensor."""
        return self._min

    @property
    def avg(self):
        """Average value of the tensor."""
        return self._avg

    @property
    def count(self):
        """Total number of elements in the tensor."""
        return self._count

    @property
    def nan_count(self):
        """Number of NAN elements."""
        return self._nan_count

    @property
    def neg_inf_count(self):
        """Number of negative INF elements."""
        return self._neg_inf_count

    @property
    def pos_inf_count(self):
        """Number of positive INF elements."""
        return self._pos_inf_count
def get_statistics_from_tensor(tensors):
    """
    Calculates statistics data of tensor.

    NAN and +/-INF elements are counted separately and excluded from the
    max/min/avg computation.

    Args:
        tensors (numpy.ndarray): An numpy.ndarray of tensor data.

    Returns:
        Statistics, the statistics of `tensors`. When no valid element exists,
        max/min/avg are reported as 0.
    """
    ma_value = np.ma.masked_invalid(tensors)
    total, valid = tensors.size, ma_value.count()

    # Count each kind of invalid value, but skip the remaining scans as soon as
    # every invalid element has already been accounted for.
    invalids = []
    for isfn in np.isnan, np.isposinf, np.isneginf:
        if total - valid > sum(invalids):
            invalids.append(np.count_nonzero(isfn(tensors)))
        else:
            invalids.append(0)
    # Same order as the predicates iterated above.
    nan_count, pos_inf_count, neg_inf_count = invalids

    if not valid:
        logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape)
        return Statistics(max_value=0,
                          min_value=0,
                          avg_value=0,
                          count=total,
                          nan_count=nan_count,
                          neg_inf_count=neg_inf_count,
                          pos_inf_count=pos_inf_count)

    # BUG: max of a masked array with dtype np.float16 returns inf
    # See numpy issue#15077
    if issubclass(tensors.dtype.type, np.floating):
        # np.PINF / np.NINF were removed in NumPy 2.0; np.inf / -np.inf are the
        # same values and exist in every NumPy version.
        tensor_min = ma_value.min(fill_value=np.inf)
        tensor_max = ma_value.max(fill_value=-np.inf)
        if tensor_min < F32_MIN or tensor_max > F32_MAX:
            logger.warning('Values(%f, %f) are too large, you may encounter some undefined '
                           'behaviours hereafter.', tensor_min, tensor_max)
    else:
        tensor_min = ma_value.min()
        tensor_max = ma_value.max()

    # Accumulate in float64 so the sum does not overflow narrow dtypes.
    tensor_sum = ma_value.sum(dtype=np.float64)
    return Statistics(max_value=tensor_max,
                      min_value=tensor_min,
                      avg_value=tensor_sum / valid,
                      count=total,
                      nan_count=nan_count,
                      neg_inf_count=neg_inf_count,
                      pos_inf_count=pos_inf_count)
| def calc_original_buckets(np_value, stats): | def calc_original_buckets(np_value, stats): | ||||
| """ | """ | ||||
| Calculate buckets from tensor data. | Calculate buckets from tensor data. | ||||
| @@ -188,7 +74,7 @@ class TensorContainer: | |||||
| self._dims = tuple(tensor_message.dims) | self._dims = tuple(tensor_message.dims) | ||||
| self._data_type = tensor_message.data_type | self._data_type = tensor_message.data_type | ||||
| self._np_array = self.get_ndarray(tensor_message.float_data) | self._np_array = self.get_ndarray(tensor_message.float_data) | ||||
| self._stats = get_statistics_from_tensor(self._np_array) | |||||
| self._stats = TensorUtils.get_statistics_from_tensor(self._np_array) | |||||
| original_buckets = calc_original_buckets(self._np_array, self._stats) | original_buckets = calc_original_buckets(self._np_array, self._stats) | ||||
| self._count = sum(bucket.count for bucket in original_buckets) | self._count = sum(bucket.count for bucket in original_buckets) | ||||
| self._max = self._stats.max | self._max = self._stats.max | ||||
| @@ -19,97 +19,16 @@ import numpy as np | |||||
| from mindinsight.datavisual.utils.tools import to_int | from mindinsight.datavisual.utils.tools import to_int | ||||
| from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError | from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError | ||||
| from mindinsight.utils.tensor import TensorUtils | |||||
| from mindinsight.conf.constants import MAX_TENSOR_RESPONSE_DATA_SIZE | from mindinsight.conf.constants import MAX_TENSOR_RESPONSE_DATA_SIZE | ||||
| from mindinsight.datavisual.common.validation import Validation | from mindinsight.datavisual.common.validation import Validation | ||||
| from mindinsight.datavisual.common.exceptions import StepTensorDataNotInCacheError, TensorNotExistError | from mindinsight.datavisual.common.exceptions import StepTensorDataNotInCacheError, TensorNotExistError | ||||
| from mindinsight.datavisual.common.exceptions import ResponseDataExceedMaxValueError | from mindinsight.datavisual.common.exceptions import ResponseDataExceedMaxValueError | ||||
| from mindinsight.datavisual.data_transform.tensor_container import TensorContainer, get_statistics_from_tensor | |||||
| from mindinsight.datavisual.data_transform.tensor_container import TensorContainer | |||||
| from mindinsight.datavisual.processors.base_processor import BaseProcessor | from mindinsight.datavisual.processors.base_processor import BaseProcessor | ||||
| from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 | from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 | ||||
def convert_array_from_str(dims, limit=0):
    """
    Parse a dims string into a list of indices.

    Args:
        dims (str): Specify dims of tensor.
        limit (int): The max flexible dimension count, default value is 0 which means that there is no limitation.

    Returns:
        list, a string like this: "[0, 0, :, :]" will convert to this value: [0, 0, None, None].

    Raises:
        ParamValueError, If flexible dimensions exceed limit value.
    """
    inner = dims.strip().lstrip('[').rstrip(']')
    tokens = [token.strip() for token in inner.split(',')]
    # ':' marks a flexible dimension (None); every other token must be an int.
    dims_list = [None if token == ':' else to_int(token, "dim") for token in tokens]
    flexible_count = tokens.count(':')
    if limit and flexible_count > limit:
        raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}"
                              .format(limit, flexible_count))
    return dims_list
def get_specific_dims_data(ndarray, dims, tensor_dims):
    """
    Get specific dims data.

    Args:
        ndarray (numpy.ndarray): An ndarray of numpy.
        dims (list): A list of specific dims. An int selects one index of that
            axis (negative values count from the end), None selects the whole axis.
        tensor_dims (list): A list of tensor dims.

    Returns:
        numpy.ndarray, an ndarray of specific dims tensor data.

    Raises:
        ParamValueError, If the length of param dims is not equal to the length of tensor dims or
            the index of param dims out of range.
    """
    if len(dims) != len(tensor_dims):
        raise ParamValueError("The length of param dims: {}, is not equal to the "
                              "length of tensor dims: {}.".format(len(dims), len(tensor_dims)))
    indices = []
    for dim, size in zip(dims, tensor_dims):
        if dim is None:
            indices.append(slice(0, size))
            continue
        # Check both ends: an index below -size would otherwise reach numpy and
        # surface as an IndexError instead of a parameter error.
        if not -size <= dim < size:
            raise ParamValueError("The index: {} of param dims out of range: {}.".format(dim, size))
        indices.append(dim)
    return ndarray[tuple(indices)]
def get_statistics_dict(stats):
    """
    Flatten a statistics object into a plain dict.

    Args:
        stats (Statistics): An instance of Statistics.

    Returns:
        dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'.
    """
    float_fields = ("max", "min", "avg")
    count_fields = ("count", "nan_count", "neg_inf_count", "pos_inf_count")
    # Value statistics are coerced to built-in float; counters are copied as-is.
    statistics = {field: float(getattr(stats, field)) for field in float_fields}
    statistics.update((field, getattr(stats, field)) for field in count_fields)
    return statistics
| class TensorProcessor(BaseProcessor): | class TensorProcessor(BaseProcessor): | ||||
| """Tensor Processor.""" | """Tensor Processor.""" | ||||
| def get_tensors(self, train_ids, tags, step, dims, detail): | def get_tensors(self, train_ids, tags, step, dims, detail): | ||||
| @@ -130,22 +49,7 @@ class TensorProcessor(BaseProcessor): | |||||
| UrlDecodeError, If unquote train id error with strict mode. | UrlDecodeError, If unquote train id error with strict mode. | ||||
| """ | """ | ||||
| Validation.check_param_empty(train_id=train_ids, tag=tags) | Validation.check_param_empty(train_id=train_ids, tag=tags) | ||||
| if dims is not None: | |||||
| if not isinstance(dims, str): | |||||
| raise ParamValueError('The type of dims must be str, but got {}.'.format(type(dims))) | |||||
| dims = dims.strip() | |||||
| if not (dims.startswith('[') and dims.endswith(']')): | |||||
| raise ParamValueError('The value: {} of dims must be ' | |||||
| 'start with `[` and end with `]`.'.format(dims)) | |||||
| for dim in dims[1:-1].split(','): | |||||
| dim = dim.strip() | |||||
| if dim == ":": | |||||
| continue | |||||
| if dim.startswith('-'): | |||||
| dim = dim[1:] | |||||
| if not dim.isdigit(): | |||||
| raise ParamValueError('The value: {} of dims in the square brackets ' | |||||
| 'must be int or `:`.'.format(dims)) | |||||
| TensorUtils.validate_dims_format(dims) | |||||
| for index, train_id in enumerate(train_ids): | for index, train_id in enumerate(train_ids): | ||||
| try: | try: | ||||
| @@ -248,7 +152,7 @@ class TensorProcessor(BaseProcessor): | |||||
| "data_type": anf_ir_pb2.DataType.Name(value.data_type) | "data_type": anf_ir_pb2.DataType.Name(value.data_type) | ||||
| } | } | ||||
| if detail and detail == 'stats': | if detail and detail == 'stats': | ||||
| stats = get_statistics_dict(value.stats) | |||||
| stats = TensorUtils.get_statistics_dict(value.stats) | |||||
| value_dict.update({"statistics": stats}) | value_dict.update({"statistics": stats}) | ||||
| values.append({ | values.append({ | ||||
| @@ -295,14 +199,14 @@ class TensorProcessor(BaseProcessor): | |||||
| """ | """ | ||||
| values = [] | values = [] | ||||
| step_in_cache = False | step_in_cache = False | ||||
| dims = convert_array_from_str(dims, limit=2) | |||||
| dims = TensorUtils.convert_array_from_str_dims(dims, limit=2) | |||||
| for tensor in tensors: | for tensor in tensors: | ||||
| # This value is an instance of TensorContainer | # This value is an instance of TensorContainer | ||||
| value = tensor.value | value = tensor.value | ||||
| if step != tensor.step: | if step != tensor.step: | ||||
| continue | continue | ||||
| step_in_cache = True | step_in_cache = True | ||||
| res_data = get_specific_dims_data(value.ndarray, dims, list(value.dims)) | |||||
| res_data = TensorUtils.get_specific_dims_data(value.ndarray, dims, list(value.dims)) | |||||
| flatten_data = res_data.flatten().tolist() | flatten_data = res_data.flatten().tolist() | ||||
| if len(flatten_data) > MAX_TENSOR_RESPONSE_DATA_SIZE: | if len(flatten_data) > MAX_TENSOR_RESPONSE_DATA_SIZE: | ||||
| raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}." | raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}." | ||||
| @@ -328,7 +232,7 @@ class TensorProcessor(BaseProcessor): | |||||
| transfer_data[index] = float(data) | transfer_data[index] = float(data) | ||||
| return transfer_data | return transfer_data | ||||
| stats = get_statistics_from_tensor(res_data) | |||||
| stats = TensorUtils.get_statistics_from_tensor(res_data) | |||||
| if stats.nan_count + stats.neg_inf_count + stats.pos_inf_count > 0: | if stats.nan_count + stats.neg_inf_count + stats.pos_inf_count > 0: | ||||
| tensor_data = transfer(res_data) | tensor_data = transfer(res_data) | ||||
| else: | else: | ||||
| @@ -340,7 +244,7 @@ class TensorProcessor(BaseProcessor): | |||||
| "dims": value.dims, | "dims": value.dims, | ||||
| "data_type": anf_ir_pb2.DataType.Name(value.data_type), | "data_type": anf_ir_pb2.DataType.Name(value.data_type), | ||||
| "data": tensor_data, | "data": tensor_data, | ||||
| "statistics": get_statistics_dict(stats) | |||||
| "statistics": TensorUtils.get_statistics_dict(stats) | |||||
| } | } | ||||
| }) | }) | ||||
| break | break | ||||
| @@ -389,7 +293,7 @@ class TensorProcessor(BaseProcessor): | |||||
| "dims": value.dims, | "dims": value.dims, | ||||
| "data_type": anf_ir_pb2.DataType.Name(value.data_type), | "data_type": anf_ir_pb2.DataType.Name(value.data_type), | ||||
| "histogram_buckets": buckets, | "histogram_buckets": buckets, | ||||
| "statistics": get_statistics_dict(value.stats) | |||||
| "statistics": TensorUtils.get_statistics_dict(value.stats) | |||||
| } | } | ||||
| }) | }) | ||||
| @@ -80,6 +80,25 @@ def to_int(param, param_name): | |||||
| return param | return param | ||||
def to_float(param, param_name):
    """
    Transfer param to float type.

    Args:
        param (Any): A param transformed.
        param_name (str): Param name.

    Returns:
        float, value after transformed.

    Raises:
        ParamTypeError: If `param` cannot be converted to float.
    """
    try:
        return float(param)
    except (TypeError, ValueError, OverflowError) as err:
        # float() raises TypeError for None/containers and OverflowError for
        # huge ints, not only ValueError for malformed strings.
        raise exceptions.ParamTypeError(param_name, 'Float') from err
| def str_to_bool(param, param_name): | def str_to_bool(param, param_name): | ||||
| """ | """ | ||||
| Check param and transform it to bool. | Check param and transform it to bool. | ||||
| @@ -0,0 +1,20 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Debugger Module Introduction. | |||||
| This module provides Python APIs to retrieve the debugger info and control the training process. | |||||
| The APIs can help users to understand the training process and find the bugs in training script. | |||||
| """ | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,56 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Debugger error code and messages.""" | |||||
| from enum import Enum, unique | |||||
| from mindinsight.utils.constant import DebuggerErrors as DebuggerErrorCodes | |||||
# Error category prefixes: the category number is shifted left by 7 bits, and
# the per-category error index is OR-ed into the low bits by DebuggerErrors.
_PARAM_ERROR_MASK = 0b00001 << 7
_DEBUGGER_GRAPH_ERROR = 0b00010 << 7
_DEBUGGER_RUNNING_ERROR = 0b00011 << 7
@unique
class DebuggerErrors(DebuggerErrorCodes):
    """
    Debugger error codes.

    Each value is an index within its category OR-ed with the category prefix
    defined above; `@unique` rejects duplicated values.
    """
    # Parameter errors.
    PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK
    PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK

    # Graph errors.
    NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR
    GRAPH_NOT_EXIST_ERROR = 1 | _DEBUGGER_GRAPH_ERROR

    # Running errors.
    CREATE_WATCHPOINT_ERROR = 0 | _DEBUGGER_RUNNING_ERROR
    UPDATE_WATCHPOINT_ERROR = 1 | _DEBUGGER_RUNNING_ERROR
    DELETE_WATCHPOINT_ERROR = 2 | _DEBUGGER_RUNNING_ERROR
    CONTINUE_ERROR = 3 | _DEBUGGER_RUNNING_ERROR
    PAUSE_ERROR = 4 | _DEBUGGER_RUNNING_ERROR
    COMPARE_TENSOR_ERROR = 5 | _DEBUGGER_RUNNING_ERROR
@unique
class DebuggerErrorMsg(Enum):
    """
    Debugger error messages.

    Templates containing `{}` are filled with `str.format` by the exception
    classes. NODE_NOT_IN_GRAPH_ERROR and COMPARE_TENSOR_ERROR are provided so
    that every code in DebuggerErrors has a matching template.
    """
    PARAM_TYPE_ERROR = "TypeError. {}"
    PARAM_VALUE_ERROR = "ValueError. {}"
    PARAM_MISSING_ERROR = "MissingError. {}"
    UNEXPECTED_EXCEPTION_ERROR = "Unexpected exception. {}"

    NODE_NOT_IN_GRAPH_ERROR = "The node is not in the graph. {}"
    GRAPH_NOT_EXIST_ERROR = "The graph does not exist."

    CREATE_WATCHPOINT_ERROR = "Create watchpoint failed. {}"
    UPDATE_WATCHPOINT_ERROR = "Update watchpoint failed. {}"
    DELETE_WATCHPOINT_ERROR = "Delete watchpoint failed. {}"
    CONTINUE_ERROR = "Continue debugging failed. {}"
    PAUSE_ERROR = "Pause debugging failed. {}"
    COMPARE_TENSOR_ERROR = "Compare tensor failed. {}"
| @@ -0,0 +1,117 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Definition of error code and relative messages in debugger module.""" | |||||
| from mindinsight.utils.exceptions import MindInsightException | |||||
| from mindinsight.debugger.common.exceptions.error_code import DebuggerErrors, DebuggerErrorMsg | |||||
class DebuggerParamTypeError(MindInsightException):
    """Raised when a debugger parameter has an invalid type."""

    def __init__(self, msg):
        """
        Args:
            msg: Detail inserted into the `PARAM_TYPE_ERROR` message template.
        """
        detail = DebuggerErrorMsg.PARAM_TYPE_ERROR.value.format(msg)
        super().__init__(error=DebuggerErrors.PARAM_TYPE_ERROR, message=detail)
class DebuggerParamValueError(MindInsightException):
    """Raised when a debugger parameter has an invalid value."""

    def __init__(self, msg):
        """
        Args:
            msg: Detail inserted into the `PARAM_VALUE_ERROR` message template.
        """
        detail = DebuggerErrorMsg.PARAM_VALUE_ERROR.value.format(msg)
        super().__init__(error=DebuggerErrors.PARAM_VALUE_ERROR, message=detail)
class DebuggerCreateWatchPointError(MindInsightException):
    """Raised when a watchpoint cannot be created."""

    def __init__(self, msg):
        """
        Args:
            msg: Detail inserted into the `CREATE_WATCHPOINT_ERROR` message template.
        """
        detail = DebuggerErrorMsg.CREATE_WATCHPOINT_ERROR.value.format(msg)
        super().__init__(error=DebuggerErrors.CREATE_WATCHPOINT_ERROR, message=detail)
class DebuggerUpdateWatchPointError(MindInsightException):
    """Raised when a watchpoint cannot be updated."""

    def __init__(self, msg):
        """
        Args:
            msg: Detail inserted into the `UPDATE_WATCHPOINT_ERROR` message template.
        """
        detail = DebuggerErrorMsg.UPDATE_WATCHPOINT_ERROR.value.format(msg)
        super().__init__(error=DebuggerErrors.UPDATE_WATCHPOINT_ERROR, message=detail)
class DebuggerDeleteWatchPointError(MindInsightException):
    """Raised when a watchpoint cannot be deleted."""

    def __init__(self, msg):
        """
        Args:
            msg: Detail inserted into the `DELETE_WATCHPOINT_ERROR` message template.
        """
        detail = DebuggerErrorMsg.DELETE_WATCHPOINT_ERROR.value.format(msg)
        super().__init__(error=DebuggerErrors.DELETE_WATCHPOINT_ERROR, message=detail)
class DebuggerCompareTensorError(MindInsightException):
    """Raised when comparing tensors fails."""

    def __init__(self, msg):
        """
        Args:
            msg: Detail inserted into the `COMPARE_TENSOR_ERROR` message template.
        """
        detail = DebuggerErrorMsg.COMPARE_TENSOR_ERROR.value.format(msg)
        super().__init__(error=DebuggerErrors.COMPARE_TENSOR_ERROR, message=detail)
class DebuggerContinueError(MindInsightException):
    """Raised when debugging cannot be continued."""

    def __init__(self, msg):
        """
        Args:
            msg: Detail inserted into the `CONTINUE_ERROR` message template.
        """
        detail = DebuggerErrorMsg.CONTINUE_ERROR.value.format(msg)
        super().__init__(error=DebuggerErrors.CONTINUE_ERROR, message=detail)
class DebuggerPauseError(MindInsightException):
    """Raised when debugging cannot be paused."""

    def __init__(self, msg):
        """
        Args:
            msg: Detail inserted into the `PAUSE_ERROR` message template.
        """
        detail = DebuggerErrorMsg.PAUSE_ERROR.value.format(msg)
        super().__init__(error=DebuggerErrors.PAUSE_ERROR, message=detail)
class DebuggerNodeNotInGraphError(MindInsightException):
    """Raised when a node name cannot be found in the graph."""

    def __init__(self, node_name, node_type=None):
        """
        Args:
            node_name: Name of the node that was looked up.
            node_type: Type of the node; mentioned in the message only when
                it is not None. Default: None.
        """
        if node_type is None:
            err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}."
        else:
            err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}, type: {node_type}."
        super().__init__(error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR, message=err_msg)
class DebuggerGraphNotExistError(MindInsightException):
    """Raised when the graph does not exist."""

    def __init__(self):
        """Initialize with the fixed `GRAPH_NOT_EXIST_ERROR` code and message."""
        fixed_message = DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value
        super().__init__(error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR, message=fixed_message)
| @@ -0,0 +1,20 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Import mindinsight unified log module.""" | |||||
| from mindinsight.utils.log import setup_logger | |||||
# Identifiers handed to the unified MindInsight logger setup.
LOG_NAME = "debugger"
LOG_MODULE = "debugger"

# Logger shared by the whole debugger package.
logger = setup_logger(log_name=LOG_NAME, sub_module=LOG_MODULE)
| @@ -0,0 +1,168 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the utils.""" | |||||
| import enum | |||||
| from collections import namedtuple | |||||
| import numpy as np | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply | |||||
| from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum | |||||
# Translate the MindSpore data type name to the matching numpy type.
# `np.bool` and `np.str` were only aliases of the Python builtins; they were
# deprecated in NumPy 1.20 and removed in 1.24, so the numpy scalar types
# `np.bool_` / `np.str_` are used instead (both resolve to the same dtypes).
NUMPY_TYPE_MAP = {
    'DT_BOOL': np.bool_,

    'DT_INT8': np.int8,
    'DT_INT16': np.int16,
    'DT_INT32': np.int32,
    'DT_INT64': np.int64,

    'DT_UINT8': np.uint8,
    'DT_UINT16': np.uint16,
    'DT_UINT32': np.uint32,
    'DT_UINT64': np.uint64,

    'DT_FLOAT16': np.float16,
    'DT_FLOAT32': np.float32,
    'DT_FLOAT64': np.float64,

    'DT_STRING': np.str_
}
@enum.unique
class ReplyStates(enum.Enum):
    """Values of the `state` field in a wrapped reply."""
    # the reply reports no error
    SUCCESS = 0
    # the reply carries an error code and an error message
    FAILED = -1
@enum.unique
class ServerStatus(enum.Enum):
    """Life-cycle states of the debugger server."""
    # no client session has been connected
    PENDING = 'pending'
    # the client session has sent the graph
    RECEIVE_GRAPH = 'receive graph'
    # the client session is ready
    WAITING = 'waiting'
    # the client session is running a script
    RUNNING = 'running'
@enum.unique
class Streams(enum.Enum):
    """Names of the streams that the debugger cache can handle."""
    # event queues
    COMMAND = "command"
    DATA = "data"

    # state holders
    METADATA = "metadata"
    GRAPH = 'node'
    TENSOR = 'tensor'
    WATCHPOINT = 'watchpoint'
    WATCHPOINT_HIT = 'watchpoint_hit'
# Lightweight record describing a node: display name, full name and node type.
NodeBasicInfo = namedtuple('node_basic_info', 'name full_name type')
def get_ack_reply(state=0):
    """
    Build an acknowledgement `EventReply` with the requested status.

    Args:
        state (int): 0 selects `OK`, 1 selects `FAILED`, 2 selects `PENDING`.
            Any other value raises KeyError. Default: 0.

    Returns:
        EventReply, a new reply whose `status` field is set.
    """
    status_by_state = {
        0: EventReply.Status.OK,
        1: EventReply.Status.FAILED,
        2: EventReply.Status.PENDING
    }
    ack = EventReply()
    ack.status = status_by_state[state]
    return ack
def wrap_reply_response(error_code=None, error_message=None):
    """
    Wrap the outcome of a request into a reply dict.

    Args:
        error_code (str): Error code. None means success. Default: None.
        error_message (str): Error message, only used when `error_code`
            is given. Default: None.

    Returns:
        dict, `{'state': SUCCESS}` on success, otherwise the FAILED state
        together with `error_code` and `error_message`.
    """
    if error_code is not None:
        return {
            'state': ReplyStates.FAILED.value,
            'error_code': error_code,
            'error_message': error_message
        }
    return {'state': ReplyStates.SUCCESS.value}
def create_view_event_from_tensor_history(tensor_history):
    """
    Build a view-command reply that requests the tensors in a tensor history.

    Const nodes are left out of the command. Parameter nodes are marked to
    be truncated.

    Args:
        tensor_history (list[dict]): The list of tensor history. The keys read
            from each element are `node_type`, `full_name` and `iter`.

    Returns:
        EventReply, the event reply with view cmd.
    """
    reply = get_ack_reply()
    for info in tensor_history:
        kind = info.get('node_type')
        if kind == NodeTypeEnum.CONST.value:
            continue
        is_parameter = kind == NodeTypeEnum.PARAMETER.value
        full_name = info.get('full_name', '')
        # one entry of the view command per requested tensor
        requested = reply.view_cmd.tensors.add()
        # `full_name` has the form "<node full name>:<slot>"
        requested.node_name, requested.slot = full_name.rsplit(':', 1)
        requested.truncate = is_parameter
        requested.iter = 'prev' if info.get('iter') else ''
    return reply
def is_scope_type(node_type):
    """
    Tell whether a node type denotes a scope.

    Args:
        node_type (str): The node type value to check.

    Returns:
        bool, True for the name-scope and aggregation-scope types.
    """
    return node_type in (NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.AGGREGATION_SCOPE.value)
def str_to_slice_or_int(input_str):
    """
    Translate param from string to slice or int.

    A string without ':' is parsed as an int. Otherwise the ':'-separated
    fields become the arguments of `slice`, with an empty field meaning None.

    Args:
        input_str (str): The string to be translated, e.g. "3", "1:5", "::2".

    Returns:
        Union[int, slice], the transformed param.

    Raises:
        DebuggerParamValueError: If a field is not an integer or there are
            more than three ':'-separated fields.
    """
    try:
        if ':' in input_str:
            fields = input_str.split(':')
            # slice() takes at most three arguments; reject longer input here
            # so it is reported as an invalid parameter instead of escaping
            # as an unhandled TypeError.
            if len(fields) > 3:
                raise ValueError("Too many ':' separators in slice string.")
            ret = slice(*[int(field.strip()) if field.strip() else None for field in fields])
        else:
            ret = int(input_str)
    except ValueError as err:
        log.error("Failed to create slice from %s", input_str)
        log.exception(err)
        raise DebuggerParamValueError("Invalid shape.") from err
    return ret
| @@ -0,0 +1,154 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Implement the debugger data cache manager.""" | |||||
| import sys | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.common.utils import Streams | |||||
| from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \ | |||||
| TensorHandler, WatchpointHandler, WatchpointHitHandler | |||||
# Handler class used for each stream, keyed by the stream's string value.
STREAM_HANDLER_MAP = {
    # both event queues share the same handler class
    Streams.COMMAND.value: EventHandler,
    Streams.DATA.value: EventHandler,

    Streams.METADATA.value: MetadataHandler,
    Streams.GRAPH.value: GraphHandler,
    Streams.TENSOR.value: TensorHandler,
    Streams.WATCHPOINT.value: WatchpointHandler,
    Streams.WATCHPOINT_HIT.value: WatchpointHitHandler
}
class DebuggerCache:
    """Cache manager that owns one handler object per debugger stream."""

    def __init__(self):
        # stream value (str) -> handler instance; filled by `initialize`
        self._stream_handler = {}

    def initialize(self):
        """Create a fresh handler for every member of `Streams`."""
        self._stream_handler = {}
        for stream in Streams:
            handler_cls = STREAM_HANDLER_MAP.get(stream.value)
            self._stream_handler[stream.value] = handler_cls()

    def clean(self):
        """Clean the cache of every stream handler."""
        for handler in self._stream_handler.values():
            handler.clean()

    def get_stream_handler(self, mode):
        """
        Look up the handler of a stream.

        Args:
            mode (Streams): The type of stream handler.

        Returns:
            StreamHandlerBase, the stream handler object, or None when no
            handler is registered for `mode`.
        """
        return self._stream_handler.get(mode.value)

    def _get(self, mode, pos):
        """
        Read one message from an event stream.

        Args:
            mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`.
            pos (int): The index of info.

        Returns:
            object, the pos-th message about `mode` type of info.
        """
        return self.get_stream_handler(mode).get(pos)

    def _put(self, mode, value):
        """
        Append one message to an event stream.

        Args:
            mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`.
            value (object): The info to be record in cache.

        Returns:
            object, whatever the handler's `put` returns.
        """
        return self.get_stream_handler(mode).put(value)

    def get_command(self, pos):
        """
        Get the pos-th command in command stream.

        Args:
            pos (int): The index of command.

        Returns:
            int, the position of next message.
            EventReply, the command object.
        """
        message = self._get(Streams.COMMAND, pos)
        following_pos = message.get('metadata').get('pos')
        command = message.get('cmd')
        return following_pos, command

    def put_command(self, cmd):
        """
        Set command to command stream.

        Args:
            cmd (EventReply): The command EventReply.

        Returns:
            object, whatever the command handler's `put` returns.
        """
        log.debug("Set command %s", cmd)
        return self._put(Streams.COMMAND, {'cmd': cmd})

    def has_command(self, pos):
        """Ask the command stream whether it holds position `pos`."""
        return self.get_stream_handler(Streams.COMMAND).has_pos(pos)

    def clean_command(self):
        """Clean command queue."""
        command_handler = self.get_stream_handler(Streams.COMMAND)
        command_handler.clean()
        log.debug("Clean command.")

    def clean_data(self):
        """Clean data queue."""
        data_handler = self.get_stream_handler(Streams.DATA)
        data_handler.clean()
        log.debug("Clean data queue.")

    def get_data(self, pos):
        """
        Get updated data from data stream.

        Args:
            pos (int): The index of data.

        Returns:
            object, updated data_value.
        """
        return self._get(Streams.DATA, pos)

    def put_data(self, value):
        """
        Set updated data to data stream.

        Args:
            value (dict): The updated data.

        Returns:
            object, whatever the data handler's `put` returns.
        """
        # sys.getsizeof is shallow: it reports the container size only
        log.debug("Set <%d> bytes data", sys.getsizeof(value))
        return self._put(Streams.DATA, value)
| @@ -0,0 +1,309 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Implement the debugger grpc server.""" | |||||
| from functools import wraps | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | |||||
| create_view_event_from_tensor_history, Streams | |||||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||||
| from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto | |||||
def debugger_wrap(func):
    """
    Decorate `func` so that any exception it raises is logged, then re-raised.

    Args:
        func (Callable): The function to wrap.

    Returns:
        Callable, a wrapper with the same metadata and return value as `func`.
    """
    @wraps(func)
    def _logged_call(*call_args, **call_kwargs):
        try:
            result = func(*call_args, **call_kwargs)
        except Exception as err:
            # record the traceback in the debugger log before propagating
            log.exception(err)
            raise err
        return result

    return _logged_call
class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
    """
    The grpc server used to interactive with grpc client.

    The servicer exchanges information with the rest of the debugger through
    the `DebuggerCache` passed to the constructor: commands are read from the
    command stream, and graph, metadata, tensor and watchpoint-hit updates are
    written to the data stream.
    """

    def __init__(self, cache_store):
        """
        Initialize.

        Args:
            cache_store (DebuggerCache): Debugger cache store.
        """
        cache_store.initialize()
        self._cache_store = cache_store
        # The attributes below are given their real initial values by `init`.
        self._pos = None
        self._status = None
        self._view_event = None
        self._view_round = None
        self._continue_steps = None
        self.init()

    def init(self):
        """Init debugger grpc server: reset the session state and clean the cache."""
        # position of the next command to read from the command stream
        self._pos = '0'
        self._status = ServerStatus.PENDING
        # pending view command, or None when there is nothing to view
        self._view_event = None
        # True while the view command may still be sent for the current step/node
        self._view_round = True
        # remaining steps to run without waiting for a command; 0 means none
        self._continue_steps = 0
        self._cache_store.clean()

    @debugger_wrap
    def WaitCMD(self, request, context):
        """
        Wait for a command in DebuggerCache.

        Args:
            request: The metadata message of the client; `cur_step` and
                `cur_node` are read from it.
            context: The grpc context (not used here).

        Returns:
            EventReply, the command for the client. A reply built with
            `get_ack_reply(1)` is returned when no graph has been received
            yet or when no command could be obtained.
        """
        # check if graph have already received.
        log.info("Received WaitCMD at %s-th step.", request.cur_step)
        if self._status == ServerStatus.PENDING:
            log.warning("No graph received before WaitCMD.")
            reply = get_ack_reply(1)
            return reply
        # send graph if has not been sent before
        self._pre_process(request)
        # deal with old command
        reply = self._deal_with_old_command()
        if reply:
            log.info("Reply to WaitCMD with old command: %s", reply)
            return reply
        # send view cmd
        if self._view_round and self._view_event:
            self._view_round = False
            reply = self._view_event
            log.debug("Send ViewCMD.")
        # continue multiple steps training
        elif self._continue_steps != 0:
            reply = get_ack_reply()
            reply.run_cmd.run_steps = 1
            reply.run_cmd.run_level = 'step'
            # a positive counter is decremented; a negative one is kept at -1
            # NOTE(review): a negative value presumably means "run until
            # paused or a watchpoint is hit" - confirm with the caller.
            self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1
            self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
            log.debug("Send RunCMD. Clean watchpoint hit.")
        # wait for command
        else:
            reply = self._wait_for_next_command()

        if reply is None:
            reply = get_ack_reply(1)
            log.warning("Failed to get command event.")
        else:
            log.info("Reply to WaitCMD: %s", reply)
        return reply

    def _pre_process(self, request):
        """
        Send graph and metadata when WaitCMD first called.

        Args:
            request: The metadata message received by `WaitCMD`.
        """
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        if self._status == ServerStatus.RECEIVE_GRAPH:
            # first WaitCMD after the graph arrived: publish graph + metadata
            self._status = ServerStatus.WAITING
            metadata_stream.state = 'waiting'
            metadata = metadata_stream.get()
            self._cache_store.clean_command()
            res = self._cache_store.get_stream_handler(Streams.GRAPH).get()
            res.update(metadata)
            self._cache_store.put_data(res)
            log.info("Put graph into data queue.")

        if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node:
            # clean tensor cache and DataQueue at the beginning of each step
            self._update_metadata(metadata_stream, request)

    def _update_metadata(self, metadata_stream, metadata_proto):
        """
        Update metadata.

        Args:
            metadata_stream: The metadata stream handler.
            metadata_proto: The new metadata message; `cur_step` and
                `cur_node` are read from it.
        """
        # reset view round and clean cache data
        self._view_round = True
        if metadata_stream.step < metadata_proto.cur_step:
            self._cache_store.clean_data()
            self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors(
                metadata_proto.cur_step)
        # put new metadata into cache
        metadata_stream.put(metadata_proto)
        # translate the full name reported by the client to the graph node name
        cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name(
            metadata_proto.cur_node) if metadata_proto.cur_node else ''
        metadata_stream.node_name = cur_node
        metadata = metadata_stream.get()
        self._cache_store.put_data(metadata)
        log.info("Put new metadata into data queue.")

    def _deal_with_old_command(self):
        """
        Deal with old command.

        Returns:
            EventReply, the first already-queued command that has to be sent
            to the client, or None if there is none.
        """
        event = None
        # consume queued commands until one of them yields a reply
        while self._cache_store.has_command(self._pos) and event is None:
            event = self._get_next_command()
            log.debug("Deal with old %s-th command:\n%s.", self._pos, event)

        return event

    def _wait_for_next_command(self):
        """
        Wait for next command.

        Returns:
            EventReply, the command event.
        """
        log.info("Start to wait for command.")
        self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting'
        self._cache_store.put_data({'metadata': {'state': 'waiting'}})
        event = None
        # loop ends when a command yields a reply or the status leaves WAITING
        # NOTE(review): presumably `get_command` blocks until a command is
        # available or times out - confirm in the event handler.
        while event is None and self._status == ServerStatus.WAITING:
            log.debug("Wait for %s-th command", self._pos)
            event = self._get_next_command()
        return event

    def _get_next_command(self):
        """
        Get next command.

        Reads the command at the current position, advances the position and
        applies the side effects of the command.

        Returns:
            EventReply, the command to send to the client, or None when the
            command only changed the server state.
        """
        self._pos, event = self._cache_store.get_command(self._pos)
        log.debug("Received event :%s", event)
        if event is None:
            return event
        if isinstance(event, dict) and event.get('reset'):
            # a dict with `reset` only replaces the stored view command
            self._set_view_event(event)
            event = None
        elif event.HasField('run_cmd'):
            event = self._deal_with_run_cmd(event)
        elif event.HasField('view_cmd'):
            self._view_round = False
        elif event.HasField('exit'):
            self._cache_store.clean()
            log.info("Clean cache for exit cmd.")
        return event

    def _deal_with_run_cmd(self, event):
        """
        Deal with run cmd.

        Args:
            event (EventReply): The event that carries a `run_cmd`.

        Returns:
            EventReply, the event to send, or None for a pause command.
        """
        run_cmd = event.run_cmd
        # receive step command
        if run_cmd.run_level == 'step':
            # receive pause cmd
            if run_cmd.run_steps == 0:
                log.debug("Pause training and wait for next command.")
                self._continue_steps = 0
                return None
            # receive step cmd
            # send one step now; the rest is replayed by `WaitCMD`
            self._continue_steps = run_cmd.run_steps - 1
            event.run_cmd.run_steps = 1
        self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean()
        log.debug("Receive RunCMD. Clean watchpoint hit cache.")

        return event

    def _set_view_event(self, event):
        """
        Create view event for view cmd.

        Args:
            event (dict): The reset command; `node_name` and `tensor_history`
                are read from it.
        """
        # the first tensor in view cmd is always the output
        node_name = event.get('node_name')
        tensor_history = event.get('tensor_history')
        if not node_name or not tensor_history:
            self._view_event = None
            log.info("Reset view command to None.")
        else:
            # create view event and set
            self._view_event = create_view_event_from_tensor_history(tensor_history)
            log.info("Reset view command to %s.", node_name)

    @debugger_wrap
    def SendMetadata(self, request, context):
        """
        Send metadata into DebuggerCache.

        Args:
            request: The metadata message sent by the client.
            context: The grpc context; `peer()` is used to get the client address.

        Returns:
            EventReply, the default ack reply.
        """
        log.info("Received Metadata.")
        if self._status != ServerStatus.PENDING:
            log.info("Re-initialize cache store when new session comes.")
            self.init()

        # drop everything up to the first ':' of the peer string
        client_ip = context.peer().split(':', 1)[-1]
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        metadata_stream.put(request)
        metadata_stream.client_ip = client_ip
        metadata = metadata_stream.get()
        # put metadata into data queue
        self._cache_store.put_data(metadata)
        log.info("Put new metadata to DataQueue.")
        reply = get_ack_reply()
        log.info("Send the reply to %s.", client_ip)
        return reply

    @debugger_wrap
    def SendGraph(self, request_iterator, context):
        """
        Send graph into DebuggerCache.

        Args:
            request_iterator: Iterator of chunks; the `buffer` fields are
                concatenated into one serialized `GraphProto`.
            context: The grpc context (not used here).

        Returns:
            EventReply, the default ack reply.
        """
        log.info("Received graph.")
        serial_graph = b""
        for chunk in request_iterator:
            serial_graph += chunk.buffer
        graph = GraphProto.FromString(serial_graph)
        log.debug("Deserialize the graph. Receive %s nodes", len(graph.node))
        self._cache_store.get_stream_handler(Streams.GRAPH).put(graph)
        self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals)
        self._status = ServerStatus.RECEIVE_GRAPH
        reply = get_ack_reply()
        log.info("Send the reply for graph.")
        return reply

    @debugger_wrap
    def SendTensors(self, request_iterator, context):
        """
        Send tensors into DebuggerCache.

        Args:
            request_iterator: Iterator of tensor messages. Messages are
                collected until one has `finished` set, then stored together
                as one tensor.
            context: The grpc context (not used here).

        Returns:
            EventReply, the default ack reply.
        """
        log.info("Received tensor.")
        tensor_construct = []
        tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
        metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
        tensor_names = []
        step = metadata_stream.step
        for tensor in request_iterator:
            tensor_construct.append(tensor)
            if tensor.finished:
                tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
                tensor_construct = []
                tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
                continue
        # send back tensor finished flag when all waiting tensor has value.
        tensor_history = tensor_stream.get_tensor_history(tensor_names)
        self._add_node_name_for_tensor_history(tensor_history)
        metadata = metadata_stream.get()
        tensor_history.update(metadata)
        self._cache_store.put_data({})  # reply to the listening request
        self._cache_store.put_data(tensor_history)
        log.info("Send updated tensor history to data queue.")
        reply = get_ack_reply()
        return reply

    def _add_node_name_for_tensor_history(self, tensor_history):
        """
        Add node name for tensor history.

        Args:
            tensor_history (dict): Its `tensor_history` list is updated in
                place: every non-empty element gets a `name` key built from
                the graph node name and the slot.
        """
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for tensor_info in tensor_history.get('tensor_history'):
            if tensor_info:
                # `full_name` has the form "<node full name>:<slot>"
                full_name, slot = tensor_info.get('full_name', '').rsplit(':', 1)
                node_name = graph_stream.get_node_name_by_full_name(full_name)
                tensor_info['name'] = node_name + ':' + slot

    @debugger_wrap
    def SendWatchpointHits(self, request_iterator, context):
        """
        Send watchpoint hits info DebuggerCache.

        Args:
            request_iterator: Iterator of watchpoint-hit messages; `tensor`
                and `id` are read from each.
            context: The grpc context (not used here).

        Returns:
            EventReply, the default ack reply.
        """
        log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps)
        # a hit cancels any remaining multi-step run and the pending view command
        self._continue_steps = 0
        self._view_event = None
        watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT)
        watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT)
        graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH)
        for watchpoint_hit_proto in request_iterator:
            watchpoint_hit = {
                'tensor_proto': watchpoint_hit_proto.tensor,
                'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id),
                'node_name': graph_stream.get_node_name_by_full_name(
                    watchpoint_hit_proto.tensor.node_name)
            }
            watchpoint_hit_stream.put(watchpoint_hit)
        watchpoint_hits_info = watchpoint_hit_stream.get()
        self._cache_store.put_data(watchpoint_hits_info)
        log.info("Send the watchpoint hits to DataQueue.\nSend the reply.")
        reply = get_ack_reply()
        return reply
| @@ -0,0 +1,752 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Implement the debugger server.""" | |||||
| import signal | |||||
| from concurrent import futures | |||||
| from threading import Thread | |||||
| import grpc | |||||
| from mindinsight.conf import settings | |||||
| from mindinsight.datavisual.data_transform.graph import NodeTypeEnum | |||||
| from mindinsight.datavisual.utils.tools import to_float | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||||
| DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ | |||||
| DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ | |||||
| create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo, \ | |||||
| str_to_slice_or_int | |||||
| from mindinsight.debugger.debugger_cache import DebuggerCache | |||||
| from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer | |||||
| from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD | |||||
| from mindinsight.utils.exceptions import MindInsightException | |||||
| class DebuggerServer: | |||||
| """The server manager of debugger.""" | |||||
def __init__(self, grpc_port=None):
    """
    Create the cache, the grpc servicer and the bookkeeping attributes.

    Args:
        grpc_port (Union[str, int, None]): The port for the grpc server. When falsy,
            `start` falls back to "50051". Default: None.
    """
    cache = DebuggerCache()
    self.cache_store = cache
    # the servicer shares the cache with this manager
    self.grpc_server = DebuggerGrpcServer(cache)
    self.grpc_port = grpc_port
    # both are populated by `start`
    self.grpc_server_manager = None
    self.back_server = None
    # id of the watchpoint used to label watched nodes, updated by `retrieve`
    self._watch_point_id = 0
def start(self):
    """
    Start the grpc server and register the SIGINT handler.

    The server listens on an insecure port built from `settings.HOST` ('[::]' when
    the setting is absent) and `self.grpc_port` ("50051" when falsy). A thread that
    waits for the termination of the server is kept in `self.back_server`.
    """
    port = self.grpc_port or "50051"
    host = getattr(settings, 'HOST', '[::]')
    hostname = f"{host}:{port}"
    # initialize a grpc server
    manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, manager)
    manager.add_insecure_port(hostname)
    manager.start()
    # start the thread which blocks until the grpc server terminates
    waiter = Thread(target=manager.wait_for_termination)
    waiter.start()
    self.back_server = waiter
    self.grpc_server_manager = manager
    # register stop server handler
    signal.signal(signal.SIGINT, self._stop_handler)
    log.info("Start grpc server %s", hostname)
def _stop_handler(self, signum, frame):
    """
    Handle the stop signal by stopping the debugger server.

    Args:
        signum (int): The signal number passed in by the `signal` module.
        frame (frame): The stack frame passed in by the `signal` module.
    """
    # stop first, then record which signal triggered it
    self.stop()
    log.debug("Deal with stop signal: %s, %s", signum, frame)
def stop(self):
    """
    Stop debugger server.

    Stops the grpc server without grace period and joins the waiting thread.
    Calling it before `start` is a no-op instead of an AttributeError.
    """
    manager = self.grpc_server_manager
    if manager is None:
        # `start` has never been called, so there is nothing to shut down
        log.info("Debugger server has not been started, nothing to stop.")
        return
    manager.stop(grace=None)
    if self.back_server is not None:
        self.back_server.join()
    log.info("Stop debugger server.")
def poll_data(self, pos):
    """
    Get the pos-th data from DebuggerCache.

    Args:
        pos (str): The index of data. Only string values are accepted.

    Returns:
        dict, the data to be updated.

    Raises:
        DebuggerParamValueError, If `pos` is not a string.
    """
    if isinstance(pos, str):
        return self.cache_store.get_data(pos)
    log.error("Pos should be string. Received: %s", pos)
    raise DebuggerParamValueError("Pos should be string.")
def search(self, name, watch_point_id):
    """
    Search for single node in graph.

    Args:
        name (str): The name pattern passed to the graph stream.
        watch_point_id (int): The id of watchpoint used to label the watched nodes.

    Returns:
        The search result of the graph stream, labelled by the watchpoint stream.
    """
    log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id)
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    searched = graph_stream.search_nodes(name)
    # add watched label
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    watchpoint_stream.set_watch_nodes(searched, watch_point_id)
    return searched
def tensor_comparisons(self, name, shape, detail='data', tolerance='0'):
    """
    Get tensor comparisons data for given name, detail, shape and tolerance.

    Args:
        name (str): The name of tensor for ui.
        shape (str): Specify concrete dimensions of shape.
        detail (str): Specify which data to query. Current available value is 'data' which means
            concrete tensor data. Histogram or unique count can be supported in the future.
        tolerance (str): Specify tolerance of difference between current step tensor and previous
            step tensor. Default value is 0.

    Raises:
        DebuggerParamValueError, If node type is not parameter or value of detail is not support.
        DebuggerCompareTensorError, If MindSpore is not in waiting state.

    Returns:
        dict, the retrieved data.
    """
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if metadata_stream.state != ServerStatus.WAITING.value:
        log.error("Failed to compare tensors as the MindSpore is not in waiting state.")
        raise DebuggerCompareTensorError(
            "Failed to compare tensors as the MindSpore is not in waiting state."
        )
    # validate and convert the request params
    self.validate_tensor_param(name, detail)
    parsed_shape = self.parse_shape(shape)
    node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
    tolerance = to_float(tolerance, 'tolerance')
    tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
    # only concrete data of parameter nodes can be compared
    if detail != 'data':
        raise DebuggerParamValueError("The value of detail: {} is not support.".format(detail))
    if node_type != NodeTypeEnum.PARAMETER.value:
        raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type))
    return tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance)
def retrieve(self, mode, filter_condition=None):
    """
    Retrieve data according to mode and params.

    Args:
        mode (str): The type of info message. Should be one of `all`, `node`,
            `watchpoint` and `watchpoint_hit`.
        filter_condition (dict): The filter condition. Default: None, which is
            treated as an empty dict.

    Returns:
        dict, the retrieved data.

    Raises:
        DebuggerParamTypeError, If `mode` is not supported.
    """
    log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode,
             filter_condition)
    mode_mapping = {
        'all': self._retrieve_all,
        'node': self._retrieve_node,
        'watchpoint': self._retrieve_watchpoint,
        'watchpoint_hit': self._retrieve_watchpoint_hit
    }
    # validate param <mode>
    if mode not in mode_mapping:
        # log the rejected value together with the modes that are really accepted
        log.error("Invalid param <mode>. <mode> should be in %s, but got %s.",
                  list(mode_mapping), mode)
        raise DebuggerParamTypeError("Invalid mode.")
    filter_condition = {} if filter_condition is None else filter_condition
    # keep the watchpoint id so that `_get_nodes_info` can label the watched nodes
    self._watch_point_id = filter_condition.get('watch_point_id', 0)
    return mode_mapping[mode](filter_condition)
def _retrieve_all(self, filter_condition=None):
    """
    Retrieve metadata, root graph and watchpoint list.

    Args:
        filter_condition (dict): Should be empty for this request. Default: None.

    Returns:
        dict, the merged replies of metadata, graph and watchpoint streams.

    Raises:
        DebuggerParamTypeError, If `filter_condition` is not empty.
    """
    if filter_condition:
        log.error("No filter condition required for retrieve all request.")
        raise DebuggerParamTypeError("filter_condition should be empty.")
    # drop the queued data and put a reset command before collecting the replies
    self.cache_store.clean_data()
    log.info("Clean data queue cache when retrieve all request.")
    self.cache_store.put_command({'reset': True})
    merged = {}
    for stream_name in (Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT):
        handler = self.cache_store.get_stream_handler(stream_name)
        merged.update(handler.get())
    return merged
def _retrieve_node(self, filter_condition):
    """
    Retrieve node info.

    The keys `node_type` and `single_node` of `filter_condition` are overwritten in place.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The name of single node.
            - watch_point_id (int): The id of watchpoint.
            - single_node (bool): If False, return the sub-layer of single node. If True, return
              the node list from root node to single node.

    Returns:
        dict, the node info.
    """
    log.info("Retrieve node %s.", filter_condition)
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    # validate parameters, a missing name is handled as a name scope
    node_name = filter_condition.get('name')
    if node_name:
        node_type = graph_stream.get_node_type(node_name)
    else:
        node_type = NodeTypeEnum.NAME_SCOPE.value
    single_node = bool(filter_condition.get('single_node'))
    filter_condition['node_type'] = node_type
    filter_condition['single_node'] = single_node
    # get graph for scope node
    if is_scope_type(node_type):
        return self._get_nodes_info(filter_condition)
    # get tensor history for leaf node
    reply = self._get_tensor_history(node_name)
    if single_node:
        reply.update(self._get_nodes_info(filter_condition))
    return reply
def _get_nodes_info(self, filter_condition):
    """
    Get nodes info.

    Args:
        filter_condition (dict): The filter condition.

            - name (str): The node name.
            - single_node (bool): If False, return the sub-layer of single node. If True, return
              the node list from root node to single node.
            - watch_point_id (int): The id of watchpoint.

    Returns:
        dict, reply with graph.
    """
    # get graph
    reply = self.cache_store.get_stream_handler(Streams.GRAPH).get(filter_condition)
    # add watched label according to the watchpoint id kept by `retrieve`
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    watchpoint_stream.set_watch_nodes(reply.get('graph'), self._watch_point_id)
    return reply
def retrieve_tensor_history(self, node_name):
    """
    Retrieve tensor history for leaf node.

    Args:
        node_name (str): The name of leaf node.

    Returns:
        dict, the tensor history and metadata.

    Raises:
        DebuggerParamValueError, If the node is a scope type node.
    """
    log.info("Retrieve tensor history for node: %s.", node_name)
    # scope nodes are rejected before the history is collected
    self._validate_leaf_name(node_name)
    return self._get_tensor_history(node_name)
def _validate_leaf_name(self, node_name):
    """
    Validate if the node is a leaf node.

    Args:
        node_name (str): The name of the node to be checked.

    Raises:
        DebuggerParamValueError, If the node type is a scope type.
    """
    node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
    if not is_scope_type(node_type):
        return
    log.error("Scope type node has no tensor history.")
    raise DebuggerParamValueError("Invalid leaf node name.")
def _get_tensor_history(self, node_name):
    """
    Get tensor history for single node.

    Args:
        node_name (str): The name of leaf node.

    Returns:
        dict, the tensor history and metadata.
    """
    get_handler = self.cache_store.get_stream_handler
    # get basic tensor history
    history = get_handler(Streams.GRAPH).get_tensor_history(node_name)
    # set the view event
    view_event = {
        'reset': True,
        'node_name': node_name,
        'tensor_history': history.get('tensor_history')
    }
    self.cache_store.put_command(view_event)
    # add tensor value for tensor history
    self._add_tensor_value_for_tensor_history(history)
    # add hit label for tensor history
    get_handler(Streams.WATCHPOINT_HIT).update_tensor_history(history)
    # add metadata
    history.update(get_handler(Streams.METADATA).get())
    return history
def _add_tensor_value_for_tensor_history(self, tensor_history):
    """
    Add tensor value for tensor history and send ViewCMD if tensor value missed.

    The tensor history is updated by the tensor stream; nothing is returned.

    Args:
        tensor_history (dict): The tensor history, as returned by the graph stream
            in `_get_tensor_history`.
    """
    tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
    missed_tensors = tensor_stream.update_tensor_history(tensor_history)
    if not missed_tensors:
        return
    # ask for the tensors whose values are missed
    self.cache_store.put_command(create_view_event_from_tensor_history(missed_tensors))
    log.debug("Send view cmd.")
def retrieve_tensor_value(self, name, detail, shape):
    """
    Retrieve the tensor value.

    Args:
        name (str): The name of tensor for ui, in the format of `node_name:slot`.
        detail (str): Specify which data to query. Only 'data' is accepted.
        shape (str): Specify concrete dimensions of shape.

    Returns:
        dict, the reply of the tensor stream, whose `tensor_value` is renamed to `name`.
    """
    log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
    self.validate_tensor_param(name, detail)
    parsed_shape = self.parse_shape(shape)
    node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
    query = {
        'name': tensor_name,
        'node_type': node_type,
        'shape': parsed_shape
    }
    reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(query)
    # reply with the ui name instead of the inner tensor name
    reply['tensor_value']['name'] = name
    return reply
def _get_tensor_name_and_type_by_ui_name(self, name):
    """
    Get inner tensor name and type by UI name.

    Args:
        name (str): Node name shown in UI, in the format of `node_name:slot`.

    Returns:
        str, node type of tensor.
        str, full name of tensor.
    """
    # the slot is the part behind the last colon
    ui_node_name, slot = name.rsplit(':', 1)
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    node_type = graph_stream.get_node_type(ui_node_name)
    full_name = graph_stream.get_full_name(ui_node_name)
    return node_type, full_name + ':' + slot
@staticmethod
def validate_tensor_param(name, detail):
    """
    Validate params for retrieve tensor request.

    Args:
        name (str): The tensor name, which should contain a colon.
        detail (str): The detail to query, which should be 'data'.

    Raises:
        DebuggerParamValueError, If `name` or `detail` is invalid.
    """
    # validate name
    name_is_valid = isinstance(name, str) and ':' in name
    if not name_is_valid:
        log.error("Invalid tensor name. Received: %s", name)
        raise DebuggerParamValueError("Invalid tensor name.")
    # validate data
    if detail == 'data':
        return
    log.error("Invalid detail value. Received: %s", detail)
    raise DebuggerParamValueError("Invalid detail value.")
@staticmethod
def parse_shape(shape):
    """
    Parse shape.

    Args:
        shape (Union[str, None]): The shape string wrapped by square brackets, in which
            the dimensions are separated by commas.

    Returns:
        Union[tuple, None], None if `shape` is None, otherwise the tuple of the
        dimensions converted by `str_to_slice_or_int`.

    Raises:
        DebuggerParamValueError, If `shape` is not a string wrapped by square brackets
            or contains more than two colons.
    """
    if shape is None:
        return shape
    is_wrapped = isinstance(shape, str) and shape.startswith('[') and shape.endswith(']')
    if not is_wrapped:
        log.error("Invalid shape. Received: %s", shape)
        raise DebuggerParamValueError("Invalid shape.")
    shape = shape.strip('[]')
    if shape.count(':') > 2:
        log.error("Invalid shape. At most two dimensions are specified.")
        raise DebuggerParamValueError("Invalid shape.")
    # an empty string inside the brackets means no dimension is specified
    if shape:
        parsed_shape = tuple(str_to_slice_or_int(dim) for dim in shape.split(','))
    else:
        parsed_shape = tuple()
    log.info("Parsed shape: %s from %s", parsed_shape, shape)
    return parsed_shape
def _retrieve_watchpoint(self, filter_condition):
    """
    Retrieve watchpoint.

    Args:
        filter_condition (dict): Filter condition.

            - watch_point_id (int): The id of watchpoint. If not given, return all watchpoints.
            - name (str): The name of single node.
            - single_node (bool): If False, return the sub-layer of single node. If True, return
              the node list from root node to single node.

    Returns:
        dict, watch point list or relative graph.
    """
    watchpoint_id = filter_condition.get('watch_point_id')
    if watchpoint_id is not None:
        # a given id means the graph of that watchpoint is required
        reply = self._retrieve_node(filter_condition)
        log.debug("Get graph of %d-th watchpoint.", watchpoint_id)
        return reply
    reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get()
    log.debug("Get condition of watchpoints.")
    return reply
def _retrieve_watchpoint_hit(self, filter_condition):
    """
    Retrieve watchpoint hit.

    Args:
        filter_condition (dict): Filter condition.

            - name (str): The name of single node. If not given, return the watchpoint hit list.
            - single_node (bool): If False, return the sub-layer of single node. If True, return
              the node list from root node to single node.

    Returns:
        dict, watch point hit list, or tensor history with relative graph.

    Raises:
        DebuggerParamValueError, If the given node is a scope type node.
    """
    hit_node_name = filter_condition.get('name')
    # get watchpoint hit list
    if hit_node_name is None:
        return self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get()
    self._validate_leaf_name(hit_node_name)
    # get tensor history
    reply = self._get_tensor_history(hit_node_name)
    log.debug("Get tensor history for watchpoint hit node.")
    # get single graph
    want_single_node = filter_condition.get('single_node')
    if want_single_node:
        reply.update(self._get_nodes_info(filter_condition))
    log.debug("Get tensor history for watchpoint hit node.")
    return reply
def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
    """
    Create watchpoint.

    Args:
        watch_condition (dict): The watch condition.

            - condition (str): Accept `INF` or `NAN`.
            - param (list[float]): Not defined yet.
        watch_nodes (list[str]): The list of node names.
        watch_point_id (int): The id of watchpoint.

    Returns:
        dict, the id of new watchpoint.

    Raises:
        DebuggerCreateWatchPointError, If MindSpore is not in waiting state.
        DebuggerParamValueError, If `OVERFLOW` condition is requested on GPU backend.
    """
    log.info("Received create watchpoint request. WatchCondition: %s", watch_condition)
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if metadata_stream.state != ServerStatus.WAITING.value:
        log.error("Failed to create watchpoint as the MindSpore is not in waiting state.")
        raise DebuggerCreateWatchPointError(
            "Failed to create watchpoint as the MindSpore is not in waiting state."
        )
    overflow_on_gpu = (metadata_stream.backend == 'GPU'
                       and watch_condition.get('condition') == 'OVERFLOW')
    if overflow_on_gpu:
        log.error("GPU doesn't support OVERFLOW watch condition.")
        raise DebuggerParamValueError("GPU doesn't support OVERFLOW watch condition.")
    # convert the node names to basic infos before creating the watchpoint
    node_infos = self._get_node_basic_infos(watch_nodes)
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    new_id = watchpoint_stream.create_watchpoint(watch_condition, node_infos, watch_point_id)
    log.info("Create watchpoint %d", new_id)
    return {'id': new_id}
def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None):
    """
    Update watchpoint.

    Args:
        watch_point_id (int): The id of watchpoint.
        watch_nodes (list[str]): The list of node names.
        mode (int): The update operator on nodes. 0 for remove nodes from watch nodes.
            1 for add nodes to watch nodes.
        name (str): The search name. Default: None.

    Returns:
        dict, empty response.

    Raises:
        DebuggerUpdateWatchPointError, If MindSpore is not in waiting state.
        DebuggerParamValueError, If `watch_nodes` or `watch_point_id` is empty.
    """
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if metadata_stream.state != ServerStatus.WAITING.value:
        log.error("Failed to update watchpoint as the MindSpore is not in waiting state.")
        raise DebuggerUpdateWatchPointError(
            "Failed to update watchpoint as the MindSpore is not in waiting state."
        )
    # validate
    if not (watch_nodes and watch_point_id):
        log.error("Invalid parameter for update watchpoint.")
        raise DebuggerParamValueError("Invalid parameter for update watchpoint.")
    # update watch node, a given search name takes precedence over the mode
    if name is not None:
        watch_nodes = self._get_watch_nodes_by_search(watch_nodes)
    elif mode == 1:
        watch_nodes = self._get_node_basic_infos(watch_nodes)
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    watchpoint_stream.update_watchpoint(watch_point_id, watch_nodes, mode)
    log.info("Update watchpoint with id: %d", watch_point_id)
    return {}
def _get_watch_nodes_by_search(self, watch_nodes):
    """
    Get watched leaf nodes by search name.

    Args:
        watch_nodes (list[str]): The list of search names, matched as name prefixes.

    Returns:
        list[NodeBasicInfo], the searched nodes whose names start with any of the search names.
    """
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    matched = []
    for prefix in watch_nodes:
        # the searched node list is queried once per search name
        candidates = graph_stream.get_searched_node_list()
        matched.extend(
            NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
            for node in candidates
            if node.name.startswith(prefix))
    log.debug("Update nodes: %s", matched)
    return matched
def delete_watchpoint(self, watch_point_id):
    """
    Delete watchpoint.

    Args:
        watch_point_id (int): The id of watchpoint.

    Returns:
        dict, empty response.

    Raises:
        DebuggerDeleteWatchPointError, If MindSpore is not in waiting state.
    """
    get_handler = self.cache_store.get_stream_handler
    if get_handler(Streams.METADATA).state != ServerStatus.WAITING.value:
        log.error("Failed to delete watchpoint as the MindSpore is not in waiting state.")
        raise DebuggerDeleteWatchPointError(
            "Failed to delete watchpoint as the MindSpore is not in waiting state."
        )
    get_handler(Streams.WATCHPOINT).delete_watchpoint(watch_point_id)
    log.info("Delete watchpoint with id: %d", watch_point_id)
    return {}
def _get_node_basic_infos(self, node_names):
    """
    Get node info according to node names.

    Args:
        node_names (list[str]): The list of node names. An empty value yields an empty list.

    Returns:
        list[NodeBasicInfo], the basic infos. An aggregation scope is replaced
        by the nodes the graph stream returns for it.
    """
    if not node_names:
        return []
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    infos = []
    for node_name in node_names:
        node_type = graph_stream.get_node_type(node_name)
        # optimizer later
        if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
            for sub_node in graph_stream.get_nodes(node_name):
                infos.append(NodeBasicInfo(
                    name=sub_node.name, full_name=sub_node.full_name, type=sub_node.type))
        else:
            full_name = graph_stream.get_full_name(node_name)
            infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type))
    return infos
def control(self, params=None):
    """
    Control the training process.

    Args:
        params (dict): The control params. Default: None, which is treated as
            an empty dict and therefore rejected as an invalid control mode.

            - mode (str): Acceptable control command, including `continue`,
              `pause` and `terminate`.
            - level (str): The control granularity, `node` level or `step` level.
              Default: `step`.
            - steps (int): Specify the steps that training should run.
              Used when `level` is `step`.
            - name (str): Specify the name of the node. Used when `level` is `node`.

    Returns:
        dict, the response.

    Raises:
        DebuggerParamValueError, If the control mode is invalid.
    """
    log.info("Receive control request: %s.", params)
    # the default None would otherwise fail with AttributeError on `.get`
    params = {} if params is None else params
    mode = params.get('mode')
    metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
    if mode == 'continue':
        return self._continue(metadata_stream, params)
    # the remaining commands only need the metadata stream
    simple_commands = {
        'pause': self._pause,
        'terminate': self._terminate
    }
    if mode in ['pause', 'terminate']:
        return simple_commands[mode](metadata_stream)
    log.error("Invalid control mode %s", mode)
    raise DebuggerParamValueError("Invalid control mode.")
def _continue(self, metadata_stream, params):
    """
    Send RunCMD to MindSpore.

    The state is set to running before the commands are queued and is set back
    to waiting if queuing fails.

    Args:
        metadata_stream (MetadataHandler): The metadata_handler
        params (dict): The control params.

    Returns:
        dict, the metadata with the running state.

    Raises:
        DebuggerContinueError, If MindSpore is not in waiting state or the run
            command can not be sent.
    """
    if metadata_stream.state != ServerStatus.WAITING.value:
        log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
        raise DebuggerContinueError(
            "MindSpore is not ready to run or is running currently."
        )
    running_state = ServerStatus.RUNNING.value
    metadata_stream.state = running_state
    try:
        run_event = self._construct_run_event(params)
        self._send_watchpoints()
        self.cache_store.put_command(run_event)
    except MindInsightException as err:
        log.error("Failed to send run event.")
        log.exception(err)
        # roll the state back as nothing is running
        metadata_stream.state = ServerStatus.WAITING.value
        raise DebuggerContinueError("Failed to send run command.")
    log.debug("Send the RunCMD to command queue.")
    return {'metadata': {'state': running_state}}
def _validate_node_type(self, node_name):
    """
    Check the node type in node control.

    Args:
        node_name (str): The name of the node. An empty name is accepted without check.

    Raises:
        DebuggerParamValueError, If the node type is one of the values of `NodeTypeEnum`.
    """
    if not node_name:
        return
    node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
    # every type listed in NodeTypeEnum is rejected here
    is_unsupported = any(item.value == node_type for item in NodeTypeEnum)
    if is_unsupported:
        log.error("Invalid node type. %s", node_name)
        raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for "
                                      "continue to command.")
def _construct_run_event(self, params):
    """
    Construct run cmd from input control params.

    Args:
        params (dict): The control params.

            - level (str): The control granularity, `node` level or `step` level.
              Default: `step`.
            - steps (int): Specify the steps that training should run.
              Used when `level` is `step`. A falsy value means 1 step.
            - name (str): Specify the name of the node. Used when `level` is `node`.
              If missing or empty, an empty node name is sent.

    Returns:
        EventReply, control event with run command.

    Raises:
        DebuggerParamValueError, If `level` is neither `step` nor `node`, or the
            type of the node is unsupported.
    """
    level = params.get('level', 'step')
    event = get_ack_reply()
    if level == 'step':
        steps = params.get('steps') or 1
        run_cmd = RunCMD(run_level='step', run_steps=steps)
    elif level == 'node':
        # `name` is optional: read it with `.get` so that a missing key does not
        # raise KeyError after `_validate_node_type` has already accepted it
        node_name = params.get('name')
        self._validate_node_type(node_name)
        full_name = ''
        if node_name:
            graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
            full_name = graph_stream.get_full_name(node_name) or ''
        run_cmd = RunCMD(run_level='node', node_name=full_name)
    else:
        log.error("Invalid Value. `level` should be `step` or `node`. Got %s", level)
        raise DebuggerParamValueError("`level` should be `step` or `node`")
    event.run_cmd.CopyFrom(run_cmd)
    log.debug("Construct run event. %s", event)
    return event
def _send_watchpoints(self):
    """
    Set watchpoints.

    Puts one set command per watchpoint returned by the watchpoint stream into the
    command queue, then calls `sync_set_cmd` on the stream. Nothing is done when
    there is no watchpoint to send.
    """
    watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
    pending = watchpoint_stream.get(filter_condition=True).get('watch_points')
    if not pending:
        return
    for set_cmd in pending:
        event = get_ack_reply()
        event.set_cmd.CopyFrom(set_cmd)
        self.cache_store.put_command(event)
    watchpoint_stream.sync_set_cmd()
    # only the last queued event is logged
    log.debug("Send SetCMD to MindSpore. %s", event)
def _pause(self, metadata_stream):
    """
    Pause the training.

    Args:
        metadata_stream (MetadataHandler): The metadata stream handler.

    Returns:
        dict, the metadata with the `waiting` state.

    Raises:
        DebuggerPauseError, If MindSpore is not running.
    """
    if metadata_stream.state != ServerStatus.RUNNING.value:
        log.error("The MindSpore is not running.")
        raise DebuggerPauseError("The MindSpore is not running.")
    paused_state = 'waiting'
    metadata_stream.state = paused_state
    # the pause command is a step-level run command with zero steps
    pause_event = get_ack_reply()
    pause_event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
    self.cache_store.put_command(pause_event)
    log.debug("Send the Pause command")
    return {'metadata': {'state': paused_state}}
def _terminate(self, metadata_stream):
    """
    Terminate the training.

    Args:
        metadata_stream (MetadataHandler): The metadata stream handler.

    Returns:
        dict, the metadata with the `pending` state.
    """
    final_state = 'pending'
    metadata_stream.state = final_state
    # the exit flag of the reply is the terminate command
    exit_event = get_ack_reply()
    exit_event.exit = True
    self.cache_store.put_command(exit_event)
    log.debug("Send the ExitCMD.")
    return {'metadata': {'state': final_state}}
def retrieve_node_by_bfs(self, node_name, ascend=False):
    """
    Get the graph and tensor history of the next node name according to node_name.

    Args:
        node_name (str): The name of the current node.
        ascend (bool): The order flag passed to the graph stream. Default: False.

    Returns:
        dict, the name, graph and tensor history of the next node. Empty if
        there is no next node.
    """
    log.info("Retrieve node <%s> by bfs, `ascend` is :%s",
             node_name, ascend)
    graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
    next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend)
    # no next node
    if next_node_name is None:
        return {}
    # add graph and tensor history for next node
    search_graph = self._get_nodes_info({'name': next_node_name, 'single_node': True})
    tensor_history = self._get_tensor_history(next_node_name)
    reply = {'name': next_node_name}
    for part in (search_graph, tensor_history):
        reply.update(part)
    return reply
| @@ -0,0 +1,113 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package debugger; | |||||
| import "mindinsight/debugger/proto/ms_graph.proto"; | |||||
// The debugger service. Every rpc returns an EventReply.
service EventListener {
  // unary request carrying Metadata
  rpc WaitCMD (Metadata) returns (EventReply) {};
  // unary request carrying Metadata
  rpc SendMetadata (Metadata) returns (EventReply) {};
  // the graph is streamed as Chunk messages
  rpc SendGraph (stream Chunk) returns (EventReply) {};
  // client-side stream of tensors
  rpc SendTensors (stream TensorProto) returns (EventReply) {};
  // client-side stream of watchpoint hits
  rpc SendWatchpointHits (stream WatchpointHit) returns (EventReply) {};
}
// The metadata sent with the WaitCMD and SendMetadata requests.
message Metadata {
  string device_name = 1;
  int32 cur_step = 2;
  // the backend type, 'GPU' or 'Ascend'
  string backend = 3;
  // the full name of current node
  string cur_node = 4;
  // check if training is done.
  bool training_done = 5;
}
// One piece of the byte stream sent by SendGraph.
message Chunk {
  bytes buffer = 1;
}
// The reply of every rpc in EventListener: a status and at most one command.
message EventReply {
  enum Status {
    OK = 0;
    FAILED = 1;
    PENDING = 2;
  }

  Status status = 1;

  // at most one of the commands below is set
  oneof cmd {
    bool exit = 2;
    RunCMD run_cmd = 3;
    SetCMD set_cmd = 4;
    ViewCMD view_cmd = 5;
  }
}
// The run command: the running level plus either a step count or a node name.
message RunCMD {
  // running level. 'step' or 'node'
  string run_level = 1;
  // only one of the fields below is set
  oneof cmd {
    int32 run_steps = 2;
    // the full name of next node
    string node_name = 3;
  }
}
// The watchpoint command: watched nodes, watch condition, delete flag and watchpoint id.
message SetCMD {
  repeated WatchNode watch_nodes = 1;
  WatchCondition watch_condition = 2;
  // NOTE(review): presumably true means the watchpoint `id` is to be deleted - confirm
  bool delete = 3;
  int32 id = 4;
}
// The view command, which carries a list of tensors.
message ViewCMD {
  repeated TensorProto tensors = 1;
}
// The condition of a watchpoint: a condition kind and one float value.
message WatchCondition {
  enum Condition {
    nan = 0;
    inf = 1;
    overflow = 2;
    max_gt = 3;
    max_lt = 4;
    min_gt = 5;
    min_lt = 6;
    max_min_gt = 7;
    max_min_lt = 8;
    mean_gt = 9;
    mean_lt = 10;
  }
  Condition condition = 1;
  // NOTE(review): a between condition needs two values but the field is a single float - confirm
  float value = 2; // for between condition, there will be two values
}
// A watched node, described by its name and type.
message WatchNode {
  string node_name = 1;
  string node_type = 2;
}
// One watchpoint hit: a tensor, a watch condition and an id.
message WatchpointHit {
  TensorProto tensor = 1;
  WatchCondition watch_condition = 2;
  // NOTE(review): presumably the id of the watchpoint which is hit - confirm
  int32 id = 3;
}
| @@ -0,0 +1,683 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # Generated by the protocol buffer compiler. DO NOT EDIT! | |||||
| # source: mindinsight/debugger/proto/debug_grpc.proto | |||||
| from google.protobuf import descriptor as _descriptor | |||||
| from google.protobuf import message as _message | |||||
| from google.protobuf import reflection as _reflection | |||||
| from google.protobuf import symbol_database as _symbol_database | |||||
| # @@protoc_insertion_point(imports) | |||||
| _sym_db = _symbol_database.Default() | |||||
| from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2 | |||||
# File-level descriptor of mindinsight/debugger/proto/debug_grpc.proto.
# The serialized file descriptor is written as one bytes literal per top-level
# entry; adjacent literals are concatenated by Python, so the resulting value
# is the same single byte string.
DESCRIPTOR = _descriptor.FileDescriptor(
    name='mindinsight/debugger/proto/debug_grpc.proto',
    package='debugger',
    syntax='proto3',
    serialized_options=None,
    serialized_pb=(
        # File name, package name and the ms_graph.proto dependency.
        b'\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto'
        # message Metadata
        b'\"k\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08'
        # message Chunk
        b'\"\x17\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c'
        # message EventReply (with nested enum Status and oneof cmd)
        b'\"\xec\x01\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md'
        # message RunCMD (with oneof cmd)
        b'\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md'
        # message SetCMD
        b'\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05'
        # message ViewCMD
        b'\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto'
        # message WatchCondition (with nested enum Condition)
        b'\"\xee\x01\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 \x01(\x02\"\x95\x01\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x07\n\x03inf\x10\x01\x12\x0c\n\x08overflow\x10\x02\x12\n\n\x06max_gt\x10\x03\x12\n\n\x06max_lt\x10\x04\x12\n\n\x06min_gt\x10\x05\x12\n\n\x06min_lt\x10\x06\x12\x0e\n\nmax_min_gt\x10\x07\x12\x0e\n\nmax_min_lt\x10\x08\x12\x0b\n\x07mean_gt\x10\t\x12\x0b\n\x07mean_lt\x10\n'
        # message WatchNode
        b'\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t'
        # message WatchpointHit
        b'\"u\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05'
        # service EventListener
        b'\x32\xc3\x02\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01'
        # syntax = "proto3"
        b'\x62\x06proto3'
    ),
    dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,])
# Descriptor of the enum debugger.EventReply.Status, registered with the
# symbol database right after construction.
_EVENTREPLY_STATUS = _descriptor.EnumDescriptor(
    name='Status',
    full_name='debugger.EventReply.Status',
    filename=None,
    file=DESCRIPTOR,
    # Each value's index equals its number, so both come from enumerate().
    values=[
        _descriptor.EnumValueDescriptor(
            name=value_name, index=position, number=position,
            serialized_options=None,
            type=None)
        for position, value_name in enumerate(('OK', 'FAILED', 'PENDING'))
    ],
    containing_type=None,
    serialized_options=None,
    serialized_start=423,
    serialized_end=464,
)
_sym_db.RegisterEnumDescriptor(_EVENTREPLY_STATUS)
# Descriptor of the enum debugger.WatchCondition.Condition, registered with
# the symbol database right after construction.
_WATCHCONDITION_CONDITION = _descriptor.EnumDescriptor(
    name='Condition',
    full_name='debugger.WatchCondition.Condition',
    filename=None,
    file=DESCRIPTOR,
    # Each value's index equals its number (0..10), so both come from
    # enumerate(); the tuple order is the declaration order.
    values=[
        _descriptor.EnumValueDescriptor(
            name=value_name, index=position, number=position,
            serialized_options=None,
            type=None)
        for position, value_name in enumerate((
            'nan',
            'inf',
            'overflow',
            'max_gt',
            'max_lt',
            'min_gt',
            'min_lt',
            'max_min_gt',
            'max_min_lt',
            'mean_gt',
            'mean_lt',
        ))
    ],
    containing_type=None,
    serialized_options=None,
    serialized_start=824,
    serialized_end=973,
)
_sym_db.RegisterEnumDescriptor(_WATCHCONDITION_CONDITION)
# Descriptor of the message debugger.Metadata (five scalar fields).
_METADATA = _descriptor.Descriptor(
    name='Metadata',
    full_name='debugger.Metadata',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='device_name', full_name='debugger.Metadata.device_name', index=0,
            number=1, type=9, cpp_type=9, label=1,
            has_default_value=False, default_value=b"".decode('utf-8'),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='cur_step', full_name='debugger.Metadata.cur_step', index=1,
            number=2, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='backend', full_name='debugger.Metadata.backend', index=2,
            number=3, type=9, cpp_type=9, label=1,
            has_default_value=False, default_value=b"".decode('utf-8'),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='cur_node', full_name='debugger.Metadata.cur_node', index=3,
            number=4, type=9, cpp_type=9, label=1,
            has_default_value=False, default_value=b"".decode('utf-8'),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='training_done', full_name='debugger.Metadata.training_done', index=4,
            number=5, type=8, cpp_type=7, label=1,
            has_default_value=False, default_value=False,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=100,
    serialized_end=207,
)
# Descriptor of the message debugger.Chunk (a single bytes field).
_CHUNK = _descriptor.Descriptor(
    name='Chunk',
    full_name='debugger.Chunk',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='buffer', full_name='debugger.Chunk.buffer', index=0,
            number=1, type=12, cpp_type=9, label=1,
            has_default_value=False, default_value=b"",
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=209,
    serialized_end=232,
)
# Descriptor of the message debugger.EventReply.
# The message/enum types of its fields and the members of the 'cmd' oneof are
# left empty here and attached by separate statements later in the module.
_EVENTREPLY = _descriptor.Descriptor(
    name='EventReply',
    full_name='debugger.EventReply',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='status', full_name='debugger.EventReply.status', index=0,
            number=1, type=14, cpp_type=8, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='exit', full_name='debugger.EventReply.exit', index=1,
            number=2, type=8, cpp_type=7, label=1,
            has_default_value=False, default_value=False,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='run_cmd', full_name='debugger.EventReply.run_cmd', index=2,
            number=3, type=11, cpp_type=10, label=1,
            has_default_value=False, default_value=None,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='set_cmd', full_name='debugger.EventReply.set_cmd', index=3,
            number=4, type=11, cpp_type=10, label=1,
            has_default_value=False, default_value=None,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='view_cmd', full_name='debugger.EventReply.view_cmd', index=4,
            number=5, type=11, cpp_type=10, label=1,
            has_default_value=False, default_value=None,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[
        _EVENTREPLY_STATUS,
    ],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[
        _descriptor.OneofDescriptor(
            name='cmd', full_name='debugger.EventReply.cmd',
            index=0, containing_type=None, fields=[]),
    ],
    serialized_start=235,
    serialized_end=471,
)
# Descriptor of the message debugger.RunCMD.
# The members of the 'cmd' oneof are attached by separate statements later in
# the module.
_RUNCMD = _descriptor.Descriptor(
    name='RunCMD',
    full_name='debugger.RunCMD',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='run_level', full_name='debugger.RunCMD.run_level', index=0,
            number=1, type=9, cpp_type=9, label=1,
            has_default_value=False, default_value=b"".decode('utf-8'),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='run_steps', full_name='debugger.RunCMD.run_steps', index=1,
            number=2, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='node_name', full_name='debugger.RunCMD.node_name', index=2,
            number=3, type=9, cpp_type=9, label=1,
            has_default_value=False, default_value=b"".decode('utf-8'),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[
        _descriptor.OneofDescriptor(
            name='cmd', full_name='debugger.RunCMD.cmd',
            index=0, containing_type=None, fields=[]),
    ],
    serialized_start=473,
    serialized_end=549,
)
# Descriptor of the message debugger.SetCMD.
# watch_nodes is the only repeated field (label=3, default []); the message
# types of watch_nodes and watch_condition are attached later in the module.
_SETCMD = _descriptor.Descriptor(
    name='SetCMD',
    full_name='debugger.SetCMD',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='watch_nodes', full_name='debugger.SetCMD.watch_nodes', index=0,
            number=1, type=11, cpp_type=10, label=3,
            has_default_value=False, default_value=[],
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='watch_condition', full_name='debugger.SetCMD.watch_condition', index=1,
            number=2, type=11, cpp_type=10, label=1,
            has_default_value=False, default_value=None,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='delete', full_name='debugger.SetCMD.delete', index=2,
            number=3, type=8, cpp_type=7, label=1,
            has_default_value=False, default_value=False,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='id', full_name='debugger.SetCMD.id', index=3,
            number=4, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=552,
    serialized_end=681,
)
# Descriptor of the message debugger.ViewCMD (one repeated message field whose
# message type is attached later in the module).
_VIEWCMD = _descriptor.Descriptor(
    name='ViewCMD',
    full_name='debugger.ViewCMD',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='tensors', full_name='debugger.ViewCMD.tensors', index=0,
            number=1, type=11, cpp_type=10, label=3,
            has_default_value=False, default_value=[],
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=683,
    serialized_end=732,
)
# Descriptor of the message debugger.WatchCondition, which nests the Condition
# enum; the enum type of the 'condition' field is attached later in the module.
_WATCHCONDITION = _descriptor.Descriptor(
    name='WatchCondition',
    full_name='debugger.WatchCondition',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='condition', full_name='debugger.WatchCondition.condition', index=0,
            number=1, type=14, cpp_type=8, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='value', full_name='debugger.WatchCondition.value', index=1,
            number=2, type=2, cpp_type=6, label=1,
            has_default_value=False, default_value=float(0),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[
        _WATCHCONDITION_CONDITION,
    ],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=735,
    serialized_end=973,
)
# Descriptor of the message debugger.WatchNode (two string fields).
_WATCHNODE = _descriptor.Descriptor(
    name='WatchNode',
    full_name='debugger.WatchNode',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='node_name', full_name='debugger.WatchNode.node_name', index=0,
            number=1, type=9, cpp_type=9, label=1,
            has_default_value=False, default_value=b"".decode('utf-8'),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='node_type', full_name='debugger.WatchNode.node_type', index=1,
            number=2, type=9, cpp_type=9, label=1,
            has_default_value=False, default_value=b"".decode('utf-8'),
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=975,
    serialized_end=1024,
)
# Descriptor of the message debugger.WatchpointHit; the message types of
# 'tensor' and 'watch_condition' are attached later in the module.
_WATCHPOINTHIT = _descriptor.Descriptor(
    name='WatchpointHit',
    full_name='debugger.WatchpointHit',
    filename=None,
    file=DESCRIPTOR,
    containing_type=None,
    fields=[
        _descriptor.FieldDescriptor(
            name='tensor', full_name='debugger.WatchpointHit.tensor', index=0,
            number=1, type=11, cpp_type=10, label=1,
            has_default_value=False, default_value=None,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='watch_condition', full_name='debugger.WatchpointHit.watch_condition', index=1,
            number=2, type=11, cpp_type=10, label=1,
            has_default_value=False, default_value=None,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
        _descriptor.FieldDescriptor(
            name='id', full_name='debugger.WatchpointHit.id', index=2,
            number=3, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            serialized_options=None, file=DESCRIPTOR),
    ],
    extensions=[],
    nested_types=[],
    enum_types=[],
    serialized_options=None,
    is_extendable=False,
    syntax='proto3',
    extension_ranges=[],
    oneofs=[],
    serialized_start=1026,
    serialized_end=1143,
)
# --- Cross-references between descriptors -----------------------------------
# The descriptors above were created with message_type/enum_type set to None
# and with empty oneof member lists; the statements below fill those in.

# EventReply: field types and the nested Status enum.
_EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS
_EVENTREPLY.fields_by_name['run_cmd'].message_type = _RUNCMD
_EVENTREPLY.fields_by_name['set_cmd'].message_type = _SETCMD
_EVENTREPLY.fields_by_name['view_cmd'].message_type = _VIEWCMD
_EVENTREPLY_STATUS.containing_type = _EVENTREPLY

# EventReply: members of the 'cmd' oneof, in declaration order.
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
    _EVENTREPLY.fields_by_name['exit'])
_EVENTREPLY.fields_by_name['exit'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
    _EVENTREPLY.fields_by_name['run_cmd'])
_EVENTREPLY.fields_by_name['run_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
    _EVENTREPLY.fields_by_name['set_cmd'])
_EVENTREPLY.fields_by_name['set_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']
_EVENTREPLY.oneofs_by_name['cmd'].fields.append(
    _EVENTREPLY.fields_by_name['view_cmd'])
_EVENTREPLY.fields_by_name['view_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd']

# RunCMD: members of the 'cmd' oneof, in declaration order.
_RUNCMD.oneofs_by_name['cmd'].fields.append(
    _RUNCMD.fields_by_name['run_steps'])
_RUNCMD.fields_by_name['run_steps'].containing_oneof = _RUNCMD.oneofs_by_name['cmd']
_RUNCMD.oneofs_by_name['cmd'].fields.append(
    _RUNCMD.fields_by_name['node_name'])
_RUNCMD.fields_by_name['node_name'].containing_oneof = _RUNCMD.oneofs_by_name['cmd']

# SetCMD, ViewCMD, WatchCondition and WatchpointHit: field types.
# TensorProto's descriptor comes from the imported ms_graph_pb2 module.
_SETCMD.fields_by_name['watch_nodes'].message_type = _WATCHNODE
_SETCMD.fields_by_name['watch_condition'].message_type = _WATCHCONDITION
_VIEWCMD.fields_by_name['tensors'].message_type = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO
_WATCHCONDITION.fields_by_name['condition'].enum_type = _WATCHCONDITION_CONDITION
_WATCHCONDITION_CONDITION.containing_type = _WATCHCONDITION
_WATCHPOINTHIT.fields_by_name['tensor'].message_type = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO
_WATCHPOINTHIT.fields_by_name['watch_condition'].message_type = _WATCHCONDITION

# Publish every top-level message on the file descriptor, then register the
# file descriptor with the symbol database.
DESCRIPTOR.message_types_by_name['Metadata'] = _METADATA
DESCRIPTOR.message_types_by_name['Chunk'] = _CHUNK
DESCRIPTOR.message_types_by_name['EventReply'] = _EVENTREPLY
DESCRIPTOR.message_types_by_name['RunCMD'] = _RUNCMD
DESCRIPTOR.message_types_by_name['SetCMD'] = _SETCMD
DESCRIPTOR.message_types_by_name['ViewCMD'] = _VIEWCMD
DESCRIPTOR.message_types_by_name['WatchCondition'] = _WATCHCONDITION
DESCRIPTOR.message_types_by_name['WatchNode'] = _WATCHNODE
DESCRIPTOR.message_types_by_name['WatchpointHit'] = _WATCHPOINTHIT
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
# --- Concrete message classes -----------------------------------------------
# One class per message descriptor, built through the reflection metaclass and
# registered with the symbol database immediately after creation.

Metadata = _reflection.GeneratedProtocolMessageType(
    'Metadata',
    (_message.Message,),
    {
        'DESCRIPTOR': _METADATA,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.Metadata)
    },
)
_sym_db.RegisterMessage(Metadata)

Chunk = _reflection.GeneratedProtocolMessageType(
    'Chunk',
    (_message.Message,),
    {
        'DESCRIPTOR': _CHUNK,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.Chunk)
    },
)
_sym_db.RegisterMessage(Chunk)

EventReply = _reflection.GeneratedProtocolMessageType(
    'EventReply',
    (_message.Message,),
    {
        'DESCRIPTOR': _EVENTREPLY,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.EventReply)
    },
)
_sym_db.RegisterMessage(EventReply)

RunCMD = _reflection.GeneratedProtocolMessageType(
    'RunCMD',
    (_message.Message,),
    {
        'DESCRIPTOR': _RUNCMD,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.RunCMD)
    },
)
_sym_db.RegisterMessage(RunCMD)

SetCMD = _reflection.GeneratedProtocolMessageType(
    'SetCMD',
    (_message.Message,),
    {
        'DESCRIPTOR': _SETCMD,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.SetCMD)
    },
)
_sym_db.RegisterMessage(SetCMD)

ViewCMD = _reflection.GeneratedProtocolMessageType(
    'ViewCMD',
    (_message.Message,),
    {
        'DESCRIPTOR': _VIEWCMD,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.ViewCMD)
    },
)
_sym_db.RegisterMessage(ViewCMD)

WatchCondition = _reflection.GeneratedProtocolMessageType(
    'WatchCondition',
    (_message.Message,),
    {
        'DESCRIPTOR': _WATCHCONDITION,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.WatchCondition)
    },
)
_sym_db.RegisterMessage(WatchCondition)

WatchNode = _reflection.GeneratedProtocolMessageType(
    'WatchNode',
    (_message.Message,),
    {
        'DESCRIPTOR': _WATCHNODE,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.WatchNode)
    },
)
_sym_db.RegisterMessage(WatchNode)

WatchpointHit = _reflection.GeneratedProtocolMessageType(
    'WatchpointHit',
    (_message.Message,),
    {
        'DESCRIPTOR': _WATCHPOINTHIT,
        '__module__': 'mindinsight.debugger.proto.debug_grpc_pb2',
        # @@protoc_insertion_point(class_scope:debugger.WatchpointHit)
    },
)
_sym_db.RegisterMessage(WatchpointHit)
# Descriptor of the service debugger.EventListener.
# All five methods return EventReply and differ only in name and request type,
# so they are generated from (name, request descriptor) pairs; the position in
# the tuple is the method index.
_EVENTLISTENER = _descriptor.ServiceDescriptor(
    name='EventListener',
    full_name='debugger.EventListener',
    file=DESCRIPTOR,
    index=0,
    serialized_options=None,
    serialized_start=1146,
    serialized_end=1469,
    methods=[
        _descriptor.MethodDescriptor(
            name=rpc_name,
            full_name='debugger.EventListener.' + rpc_name,
            index=position,
            containing_service=None,
            input_type=request_descriptor,
            output_type=_EVENTREPLY,
            serialized_options=None,
        )
        for position, (rpc_name, request_descriptor) in enumerate((
            ('WaitCMD', _METADATA),
            ('SendMetadata', _METADATA),
            ('SendGraph', _CHUNK),
            ('SendTensors', mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO),
            ('SendWatchpointHits', _WATCHPOINTHIT),
        ))
    ])
_sym_db.RegisterServiceDescriptor(_EVENTLISTENER)

DESCRIPTOR.services_by_name['EventListener'] = _EVENTLISTENER
| # @@protoc_insertion_point(module_scope) | |||||
| @@ -0,0 +1,193 @@ | |||||
| # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! | |||||
| import grpc | |||||
| from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2 | |||||
| from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2 | |||||
class EventListenerStub(object):
    """Client-side stub for the ``debugger.EventListener`` service.

    Each RPC of the service is exposed as an instance attribute holding the
    callable created from the channel.
    """

    def __init__(self, channel):
        """Create one callable per RPC on the given channel.

        Args:
            channel: A grpc.Channel.
        """
        # Short local aliases for the two generated message modules.
        debug_pb2 = mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2
        graph_pb2 = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2
        # Every RPC of this service is answered with an EventReply.
        parse_reply = debug_pb2.EventReply.FromString

        # Single request, single response.
        self.WaitCMD = channel.unary_unary(
            '/debugger.EventListener/WaitCMD',
            request_serializer=debug_pb2.Metadata.SerializeToString,
            response_deserializer=parse_reply,
        )
        self.SendMetadata = channel.unary_unary(
            '/debugger.EventListener/SendMetadata',
            request_serializer=debug_pb2.Metadata.SerializeToString,
            response_deserializer=parse_reply,
        )

        # Stream of requests, single response.
        self.SendGraph = channel.stream_unary(
            '/debugger.EventListener/SendGraph',
            request_serializer=debug_pb2.Chunk.SerializeToString,
            response_deserializer=parse_reply,
        )
        self.SendTensors = channel.stream_unary(
            '/debugger.EventListener/SendTensors',
            request_serializer=graph_pb2.TensorProto.SerializeToString,
            response_deserializer=parse_reply,
        )
        self.SendWatchpointHits = channel.stream_unary(
            '/debugger.EventListener/SendWatchpointHits',
            request_serializer=debug_pb2.WatchpointHit.SerializeToString,
            response_deserializer=parse_reply,
        )
class EventListenerServicer(object):
    """Server-side base class of the `debugger.EventListener` service.

    Every handler here reports UNIMPLEMENTED; a concrete servicer is expected
    to override the methods it supports.
    """

    @staticmethod
    def _reject_unimplemented(context):
        """Flag the call as UNIMPLEMENTED on `context` and raise NotImplementedError."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def WaitCMD(self, request, context):
        """Unary RPC: one Metadata request, one EventReply response."""
        self._reject_unimplemented(context)

    def SendMetadata(self, request, context):
        """Unary RPC: one Metadata request, one EventReply response."""
        self._reject_unimplemented(context)

    def SendGraph(self, request_iterator, context):
        """Client-streaming RPC: a stream of Chunk messages, one EventReply response."""
        self._reject_unimplemented(context)

    def SendTensors(self, request_iterator, context):
        """Client-streaming RPC: a stream of TensorProto messages, one EventReply response."""
        self._reject_unimplemented(context)

    def SendWatchpointHits(self, request_iterator, context):
        """Client-streaming RPC: a stream of WatchpointHit messages, one EventReply response."""
        self._reject_unimplemented(context)
def add_EventListenerServicer_to_server(servicer, server):
    """Register the handlers of `servicer` on `server` as `debugger.EventListener`.

    Args:
        servicer: Object providing the WaitCMD, SendMetadata, SendGraph,
            SendTensors and SendWatchpointHits handlers.
        server: Object exposing `add_generic_rpc_handlers` (a grpc server).
    """
    debug_pb2 = mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2
    graph_pb2 = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2
    # All five RPCs serialize their response as an EventReply.
    dump_reply = debug_pb2.EventReply.SerializeToString

    rpc_method_handlers = {
        'WaitCMD': grpc.unary_unary_rpc_method_handler(
            servicer.WaitCMD,
            request_deserializer=debug_pb2.Metadata.FromString,
            response_serializer=dump_reply,
        ),
        'SendMetadata': grpc.unary_unary_rpc_method_handler(
            servicer.SendMetadata,
            request_deserializer=debug_pb2.Metadata.FromString,
            response_serializer=dump_reply,
        ),
        'SendGraph': grpc.stream_unary_rpc_method_handler(
            servicer.SendGraph,
            request_deserializer=debug_pb2.Chunk.FromString,
            response_serializer=dump_reply,
        ),
        'SendTensors': grpc.stream_unary_rpc_method_handler(
            servicer.SendTensors,
            request_deserializer=graph_pb2.TensorProto.FromString,
            response_serializer=dump_reply,
        ),
        'SendWatchpointHits': grpc.stream_unary_rpc_method_handler(
            servicer.SendWatchpointHits,
            request_deserializer=debug_pb2.WatchpointHit.FromString,
            response_serializer=dump_reply,
        ),
    }
    service_handler = grpc.method_handlers_generic_handler(
        'debugger.EventListener', rpc_method_handlers)
    server.add_generic_rpc_handlers((service_handler,))
# This class is part of an EXPERIMENTAL API.
class EventListener(object):
    """Channel-less wrappers for the `debugger.EventListener` RPCs.

    Each static method forwards to `grpc.experimental.unary_unary` or
    `grpc.experimental.stream_unary` with the method path and serializers of
    the corresponding RPC.

    The request, target, method path and the two serializers are passed
    positionally, exactly as before. The optional call settings are passed by
    keyword so that they cannot be bound to the wrong parameter when the
    installed grpc release has additional optional parameters in between
    (positional forwarding ties the code to one exact signature).
    """

    @staticmethod
    def WaitCMD(request,
                target,
                options=(),
                channel_credentials=None,
                call_credentials=None,
                compression=None,
                wait_for_ready=None,
                timeout=None,
                metadata=None):
        """Call the unary `WaitCMD` RPC on `target` with a Metadata request."""
        return grpc.experimental.unary_unary(
            request, target, '/debugger.EventListener/WaitCMD',
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options=options,
            channel_credentials=channel_credentials,
            call_credentials=call_credentials,
            compression=compression,
            wait_for_ready=wait_for_ready,
            timeout=timeout,
            metadata=metadata)

    @staticmethod
    def SendMetadata(request,
                     target,
                     options=(),
                     channel_credentials=None,
                     call_credentials=None,
                     compression=None,
                     wait_for_ready=None,
                     timeout=None,
                     metadata=None):
        """Call the unary `SendMetadata` RPC on `target` with a Metadata request."""
        return grpc.experimental.unary_unary(
            request, target, '/debugger.EventListener/SendMetadata',
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options=options,
            channel_credentials=channel_credentials,
            call_credentials=call_credentials,
            compression=compression,
            wait_for_ready=wait_for_ready,
            timeout=timeout,
            metadata=metadata)

    @staticmethod
    def SendGraph(request_iterator,
                  target,
                  options=(),
                  channel_credentials=None,
                  call_credentials=None,
                  compression=None,
                  wait_for_ready=None,
                  timeout=None,
                  metadata=None):
        """Call the client-streaming `SendGraph` RPC on `target` with Chunk messages."""
        return grpc.experimental.stream_unary(
            request_iterator, target, '/debugger.EventListener/SendGraph',
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options=options,
            channel_credentials=channel_credentials,
            call_credentials=call_credentials,
            compression=compression,
            wait_for_ready=wait_for_ready,
            timeout=timeout,
            metadata=metadata)

    @staticmethod
    def SendTensors(request_iterator,
                    target,
                    options=(),
                    channel_credentials=None,
                    call_credentials=None,
                    compression=None,
                    wait_for_ready=None,
                    timeout=None,
                    metadata=None):
        """Call the client-streaming `SendTensors` RPC on `target` with TensorProto messages."""
        return grpc.experimental.stream_unary(
            request_iterator, target, '/debugger.EventListener/SendTensors',
            mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options=options,
            channel_credentials=channel_credentials,
            call_credentials=call_credentials,
            compression=compression,
            wait_for_ready=wait_for_ready,
            timeout=timeout,
            metadata=metadata)

    @staticmethod
    def SendWatchpointHits(request_iterator,
                           target,
                           options=(),
                           channel_credentials=None,
                           call_credentials=None,
                           compression=None,
                           wait_for_ready=None,
                           timeout=None,
                           metadata=None):
        """Call the client-streaming `SendWatchpointHits` RPC on `target` with WatchpointHit messages."""
        return grpc.experimental.stream_unary(
            request_iterator, target, '/debugger.EventListener/SendWatchpointHits',
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString,
            mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString,
            options=options,
            channel_credentials=channel_credentials,
            call_credentials=call_credentials,
            compression=compression,
            wait_for_ready=wait_for_ready,
            timeout=timeout,
            metadata=metadata)
| @@ -0,0 +1,322 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| syntax = "proto2"; | |||||
| package debugger; | |||||
| // Versioning: the IR version numbers known to this schema. | |||||
| enum Version { | |||||
| // Unknown version. | |||||
| // NOTE(review): the identifier is spelled UNKNOWWN (double W). It is kept | |||||
| // as-is here because renaming it changes the generated constant name. | |||||
| UNKNOWWN_VERSION = 0; | |||||
| // Initial version (IR VERSION 1), published on Sep 23, 2019 | |||||
| IR_VERSION = 0x0000000000000001; | |||||
| } | |||||
| // Data type definition: scalar types, list types, container types and | |||||
| // type-related markers, each with a fixed numeric value. | |||||
| enum DataType { | |||||
| DT_UNDEFINED = 0; | |||||
| // Basic types. | |||||
| DT_BOOL = 1; // bool | |||||
| DT_INT8 = 2; // int8_t | |||||
| DT_INT16 = 3; // int16_t | |||||
| DT_INT32 = 4; // int32_t | |||||
| DT_INT64 = 5; // int64_t | |||||
| DT_UINT8 = 6; // uint8_t | |||||
| DT_UINT16 = 7; // uint16_t | |||||
| DT_UINT32 = 8; // uint32_t | |||||
| DT_UINT64 = 9; // uint64_t | |||||
| DT_FLOAT16 = 10; // float 16 | |||||
| DT_FLOAT32 = 11; // float 32 | |||||
| DT_FLOAT64 = 12; // float 64 | |||||
| DT_STRING = 13; // string | |||||
| DT_TENSOR = 14; // tensor | |||||
| DT_GRAPH = 15; // graph | |||||
| // list type | |||||
| DT_BOOLS = 16; // list of bool | |||||
| DT_INTS8 = 17; // list of int8_t | |||||
| DT_INTS16 = 18; // list of int16_t | |||||
| DT_INTS32 = 19; // list of int32_t | |||||
| DT_INTS64 = 20; // list of int64_t | |||||
| DT_UINTS8 = 21; // list of uint8_t | |||||
| DT_UINTS16 = 22; // list of uint16_t | |||||
| DT_UINTS32 = 23; // list of uint32_t | |||||
| DT_UINTS64 = 24; // list of uint64_t | |||||
| DT_FLOATS16 = 25; // list of float16 | |||||
| DT_FLOATS32 = 26; // list of float32 | |||||
| DT_FLOATS64 = 27; // list of float64 | |||||
| DT_STRINGS = 28; // list of string | |||||
| DT_TENSORS = 29; // list of tensor | |||||
| DT_GRAPHS = 30; // list of graph | |||||
| DT_TUPLE = 31; // tuple | |||||
| DT_LIST = 32; // list | |||||
| DT_DICT = 33; // dictionary | |||||
| // other types | |||||
| DT_NONE = 34; // None | |||||
| DT_SYM_INST = 35; // Symbolic Key Instance | |||||
| // type related type | |||||
| DT_BASE_INT = 36; // type generic int | |||||
| DT_BASE_UINT = 37; // type generic unsigned int | |||||
| DT_BASE_FLOAT = 38; // type generic float | |||||
| DT_TYPE = 39; // type type | |||||
| DT_ANYTHING = 40; // type anything | |||||
| DT_REFKEY = 41; // type refkey | |||||
| DT_REF = 42; // type ref | |||||
| } | |||||
| // Value definition for attribute value or parameter default value | |||||
| message ValueProto { | |||||
| // data type of value | |||||
| optional DataType dtype = 1; // discriminator that indicates which field below is in use | |||||
| // Exactly ONE of the following fields must be present for this version of the IR | |||||
| optional bool bool_val = 2; // bool | |||||
| optional int64 int_val = 3; // int | |||||
| optional uint64 uint_val = 4; // uint | |||||
| optional float float_val = 5; // float | |||||
| optional double double_val = 6; // double | |||||
| optional string str_val = 7; // string | |||||
| optional TensorProto tensor_val = 8; // tensor value | |||||
| optional GraphProto graph = 9; // graph | |||||
| repeated bool bool_vals = 10; // list of bool | |||||
| repeated int64 int_vals = 11; // list of int | |||||
| repeated uint64 uint_vals = 12; // list of uint | |||||
| repeated float float_vals = 13; // list of float | |||||
| repeated double double_vals = 14; // list of double | |||||
| repeated string str_vals = 15; // list of string | |||||
| repeated TensorProto tensor_vals = 16; // list of tensor value | |||||
| repeated GraphProto graphs = 17; // list of graph | |||||
| // tuple or list | |||||
| repeated ValueProto values = 18; // tuple, list of value | |||||
| // dictionary | |||||
| repeated NamedValueProto dict_val = 19; // dictionary info | |||||
| // field for type type | |||||
| optional TypeProto type_val = 20; // type type info | |||||
| } | |||||
| // A named attribute: an attribute name paired with its value. | |||||
| message AttributeProto { | |||||
| optional string name = 1; // attribute name | |||||
| optional ValueProto value = 2; // attribute value | |||||
| } | |||||
| // A key/value pair; used in this file for dictionary entries | |||||
| // (ValueProto.dict_val) and graph constants (GraphProto.const_vals). | |||||
| message NamedValueProto { | |||||
| optional string key = 1; // entry key | |||||
| optional ValueProto value = 2; // entry value | |||||
| } | |||||
| // Defines a tensor shape as a list of dimensions. | |||||
| message TensorShapeProto { | |||||
| // One dimension of the tensor. | |||||
| message Dimension { | |||||
| // Size of the tensor in that dimension. | |||||
| // This value must be >= -1, but values of -1 are reserved for "unknown" | |||||
| // shapes (values of -1 mean "unknown" dimension). | |||||
| optional int64 size = 1; | |||||
| // Optional name of the tensor dimension. | |||||
| optional string name = 2; | |||||
| }; | |||||
| // The dimensions of the shape, one entry per dimension. | |||||
| repeated Dimension dim = 1; | |||||
| } | |||||
| // Types for graph input (parameter) and output | |||||
| message TypeProto { | |||||
| // Tensor type: element type plus optional shape. | |||||
| message Tensor { | |||||
| // This field MUST have a valid DataType value except DT_TENSOR | |||||
| optional DataType elem_type = 1; | |||||
| optional TensorShapeProto shape = 2; // for scalar, this field is not set | |||||
| } | |||||
| // tuple type | |||||
| message Sequence { | |||||
| // The type and optional shape of elements of the tuple. | |||||
| repeated TypeProto elem_types = 1; | |||||
| }; | |||||
| // data type | |||||
| optional DataType data_type = 1; | |||||
| // At most one of the following detailed descriptions is set. | |||||
| oneof value { | |||||
| // The type of a tensor. | |||||
| Tensor tensor_type = 2; | |||||
| // The type of a tuple. | |||||
| Sequence sequence_type = 3; | |||||
| } | |||||
| } | |||||
| // Defines information on graph parameters, including the name, the type, and | |||||
| // the default value of the parameter if one exists. | |||||
| message ParameterProto { | |||||
| optional string name = 1; // parameter name | |||||
| optional TypeProto type = 2; // parameter type | |||||
| optional ValueProto default_val = 3; // default value of parameter if exists | |||||
| } | |||||
| // Defines graph output information: the output node name and its type. | |||||
| message OutputProto { | |||||
| optional string name = 1; // output node name | |||||
| optional TypeProto type = 2; // output node type | |||||
| } | |||||
| // Defines node input information: the input name and the kind of edge. | |||||
| message InputProto { | |||||
| // Kind of edge that connects the input to the node. | |||||
| enum EdgeType { | |||||
| DATA_EDGE = 0; // data edge | |||||
| CONTROL_EDGE = 1; // control edge | |||||
| } | |||||
| optional string name = 1; // input name | |||||
| optional EdgeType type = 2; // edge type of this input | |||||
| } | |||||
| // Nodes | |||||
| // | |||||
| // Computation graphs are made up of a DAG of nodes, which represent what is | |||||
| // commonly called a "layer" or "pipeline stage" in machine learning frameworks. | |||||
| // | |||||
| // For example, it can be a node of type "Conv" that takes in an image, a filter | |||||
| // tensor and a bias tensor, and produces the convolved output. | |||||
| message NodeProto { | |||||
| repeated InputProto input = 1; // namespace Value | |||||
| optional string name = 2; // namespace Value | |||||
| // The symbolic identifier of the Operator to execute. | |||||
| optional string op_type = 3; // namespace Operator | |||||
| // The domain of the OperatorSet that specifies the operator named by op_type. | |||||
| optional string scope = 4; // namespace Domain | |||||
| // Additional named attributes. | |||||
| repeated AttributeProto attribute = 5; | |||||
| // Optional type info of this node | |||||
| optional TypeProto output_type = 6; | |||||
| // other fields for debug | |||||
| optional uint64 output_i = 7; | |||||
| // full name with scope | |||||
| optional string full_name = 8; | |||||
| } | |||||
| // Models | |||||
| // | |||||
| // ModelProto is a top-level file/container format for bundling a ML model and | |||||
| // associating its computation graph with metadata. | |||||
| // | |||||
| // The semantics of the model are described by the associated GraphProto. | |||||
| message ModelProto { | |||||
| // ir version | |||||
| optional int64 ir_version = 1; | |||||
| // Domain name of the model. | |||||
| // We use reverse domain names as name space indicators. For example: | |||||
| // `com.facebook.fair` or `com.microsoft.cognitiveservices` | |||||
| // | |||||
| // Together with `model_version` and GraphProto.name, this forms the unique identity of | |||||
| // the graph. | |||||
| optional string domain = 2; | |||||
| // The version of the graph encoded. See the Version enum defined earlier in this file. | |||||
| optional int64 model_version = 3; | |||||
| // The parameterized graph that is evaluated to execute the model. | |||||
| optional GraphProto graph = 4; | |||||
| // metadata info of operators | |||||
| optional OperatorSetProto metadata_operators = 5; | |||||
| }; | |||||
| // Metadata of a single operator, keyed by its name. | |||||
| message OperatorProto { | |||||
| optional string name = 1; // used as key, must be distinct | |||||
| optional bytes config = 2; // operator config info | |||||
| optional bytes obj_info = 3; // operator related object info, e.g. content of operator binary or name | |||||
| }; | |||||
| // A collection of operator metadata entries plus a summary string. | |||||
| message OperatorSetProto { | |||||
| repeated OperatorProto operators = 1; | |||||
| optional string summary = 2; // summary info of operators, e.g. file position of operators file | |||||
| } | |||||
| // Graphs | |||||
| // | |||||
| // A graph defines the computational logic of a model and is comprised of a parameterized | |||||
| // list of nodes that form a directed acyclic graph based on their inputs and outputs. | |||||
| // This is the equivalent of the "network" or "graph" in many deep learning | |||||
| // frameworks. | |||||
| message GraphProto { | |||||
| // The nodes in the graph, sorted topologically. | |||||
| repeated NodeProto node = 1; | |||||
| // The name of the graph. | |||||
| optional string name = 2; // namespace Graph | |||||
| // The parameters (inputs) and outputs of the graph. | |||||
| repeated ParameterProto parameters = 3; | |||||
| repeated OutputProto outputs = 4; | |||||
| // Constants used in this graph, stored as key/value pairs. | |||||
| repeated NamedValueProto const_vals = 5; | |||||
| } | |||||
| // Tensors | |||||
| // | |||||
| // A serialized tensor value. | |||||
| message TensorProto { | |||||
| // The node name of the tensor. | |||||
| optional string node_name = 1; | |||||
| // The slot of the tensor in its node. | |||||
| optional string slot = 2; | |||||
| // The serialized tensor content. | |||||
| optional bytes tensor_content = 3; | |||||
| // The shape of the tensor. | |||||
| repeated int64 dims = 4; | |||||
| // The data type of the tensor. | |||||
| // This field MUST have a valid DataType value except DT_TENSOR | |||||
| optional DataType data_type = 5; | |||||
| // Whether transferring of the tensor content is finished. | |||||
| optional bool finished = 6; | |||||
| // The iteration of the tensor. Supported: "prev" or leave empty. | |||||
| optional string iter = 7; | |||||
| // Whether the tensor name should be truncated. | |||||
| optional bool truncate = 8; | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| @@ -0,0 +1,289 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """This file is used to define the basic graph.""" | |||||
| from collections import deque | |||||
| from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph | |||||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||||
| from mindinsight.debugger.common.exceptions.exceptions import \ | |||||
| DebuggerNodeNotInGraphError, DebuggerParamValueError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from .node import NodeTree | |||||
class DebuggerGraph(MSGraph):
    """The `DebuggerGraph` object provides interfaces to describe a debugger graph."""

    def __init__(self):
        super(DebuggerGraph, self).__init__()
        # Scratch tree; rebuilt from scratch by every `get_nodes` call.
        self._node_tree = None

    def get_node_name_by_full_name(self, full_name):
        """
        Get the inner node name that is mapped to a full name.

        Args:
            full_name (str): The full name to look up.

        Returns:
            str, the mapped node name, or '' (with a warning) if there is no mapping.
        """
        inner_name = self._full_name_map_name.get(full_name, '')
        if not inner_name:
            log.warning("Node %s does not find the relative inner node name.", full_name)
        return inner_name

    def get_full_name_by_node_name(self, node_name):
        """
        Get the full name of a node by its node name.

        Args:
            node_name (str): The node name to look up in the normal node map.

        Returns:
            str, the full name of the node, or '' (with a warning) if it is not found.
        """
        node = self._normal_node_map.get(node_name)
        if not node:
            log.warning("Node %s is not leaf node.", node_name)
            return ''
        return node.full_name

    def get_nodes(self, searched_node_list):
        """
        Organize searched leaf nodes into a nested list that follows their scopes.

        Args:
            searched_node_list (list[Node]): A list of leaf nodes that
                matches the given search pattern.

        Returns:
            A list of dict including the searched nodes.
            [{
                "name": "Default",
                "type": "name_scope",
                "nodes": [{
                    "name": "Default/Conv2D1",
                    "type": "name_scope",
                    "nodes": [{
                        ...
                    }]
                }]
            },
            {
                "name": "Gradients",
                "type": "name_scope",
                "nodes": [{
                    "name": "Gradients/Default",
                    "type": "name_scope",
                    "nodes": [{
                        ...
                    }]
                }]
            }]
        """
        # save the node in the NodeTree
        self._node_tree = NodeTree()
        for node in searched_node_list:
            self._build_node_tree(node.name, node.type)

        # get the searched nodes in the NodeTree and reorganize them
        searched_list = []
        self._traverse_node_tree(self._node_tree, searched_list)
        return searched_list

    def search_nodes_by_pattern(self, pattern):
        """
        Search leaf nodes whose name contains a given pattern (case-insensitive).

        Args:
            pattern (Union[str, None]): The pattern of the node to search,
                if None, return all leaf nodes.

        Returns:
            list[Node], the matching leaf nodes.
        """
        if pattern is None:
            return list(self._leaf_nodes.values())
        pattern = pattern.lower()
        return [
            node for name, node in self._leaf_nodes.items()
            if pattern in name.lower()
        ]

    def _build_node_tree(self, node_name, node_type):
        """Insert a node into the node tree, creating missing scope levels on the way."""
        scope_names = node_name.split('/')
        cur_node = self._node_tree
        for scope_name in scope_names[:-1]:
            sub_node = cur_node.get(scope_name)
            if sub_node is None:
                sub_node = cur_node.add(scope_name)
            cur_node = sub_node
        # The last path component is the node itself and carries the type.
        cur_node.add(scope_names[-1], node_type)

    def _traverse_node_tree(self, cur_node, search_node_list):
        """
        Append one dict per child of `cur_node` to `search_node_list`, recursively.

        Each dict has the keys 'name', 'type' and 'nodes' (the child's own children).
        """
        # `get_children` is a generator, so an empty node simply yields nothing;
        # no separate emptiness check is needed (or possible via truthiness).
        for _, sub_node in cur_node.get_children():
            sub_nodes = []
            self._traverse_node_tree(sub_node, sub_nodes)
            search_node_list.append({
                'name': sub_node.node_name,
                'type': sub_node.node_type,
                'nodes': sub_nodes
            })

    def get_node_type(self, node_name):
        """
        Get the type of the node.

        Args:
            node_name (str): The full name of the node with its scope.

        Returns:
            A string, leaf or name_scope.

        Raises:
            DebuggerNodeNotInGraphError: If `node_name` is non-empty and not in the graph.
        """
        if node_name and not self.exist_node(name=node_name):
            raise DebuggerNodeNotInGraphError(node_name=node_name)

        node = self._leaf_nodes.get(node_name)
        if node is not None:
            return node.type
        return NodeTypeEnum.NAME_SCOPE.value

    def get_tensor_history(self, node_name, depth=0):
        """
        Get the tensor history of a specified node.

        Args:
            node_name (str): The debug name of the node.
            depth (int): The number of layers the user wants to trace. Default is 0.

        Returns:
            list, a list of the traced tensors' name and node type,
                arranged in order from leaf node to root node.
            int, the number of output tensors.

        Raises:
            DebuggerNodeNotInGraphError: If `node_name` is not a leaf node of the graph.
        """
        node = self._leaf_nodes.get(node_name)
        if node is None:
            # Fail with the domain error instead of an AttributeError on None.
            raise DebuggerNodeNotInGraphError(node_name=node_name)

        tensor_history = self._get_tensor_infos_of_node(node)
        cur_outputs_nums = len(tensor_history)
        trace_list = deque([(node, 0)])
        while trace_list:
            cur_node, cur_depth = trace_list.popleft()
            tensor_history.extend(self._get_input_tensors_of_node(cur_node))
            if cur_depth >= depth:
                continue
            for name in cur_node.input.keys():
                input_node = self._leaf_nodes.get(name)
                if input_node is None:
                    # Already reported by `_get_input_tensors_of_node`.
                    continue
                trace_list.append((input_node, cur_depth + 1))

        return tensor_history, cur_outputs_nums

    @staticmethod
    def _get_tensor_infos_of_node(cur_node, slot=None):
        """
        Get tensors info of specified node.

        Args:
            cur_node (Node): The node whose output tensors are described.
            slot (Union[int, None]): A single output slot; if None, all
                `cur_node.output_nums` slots are used. A negative slot yields
                an empty list.

        Returns:
            list[dict], one dict per slot with the keys 'name', 'full_name'
                and 'node_type'.
        """
        tensors_info = []
        if slot is None:
            slots = range(cur_node.output_nums)
        elif slot >= 0:
            slots = [slot]
        else:
            log.info("Skip get tensor info for %s:%s.", cur_node.name, slot)
            return tensors_info
        for num in slots:
            tensors_info.append({
                'name': cur_node.name + ':' + str(num),
                'full_name': cur_node.full_name + ':' + str(num),
                'node_type': cur_node.type
            })

        return tensors_info

    def _get_input_tensors_of_node(self, cur_node):
        """Get the tensor infos of all inputs of `cur_node` that are leaf nodes."""
        tensors_info = []
        for name in cur_node.input.keys():
            node = self._leaf_nodes.get(name)
            if node is None:
                # An input without a leaf node cannot be described; skip it
                # rather than failing on attribute access of None.
                log.warning('Cannot find node %s in graph. Ignored.', name)
                continue
            tensors_info.extend(self._get_tensor_infos_of_node(node))

        return tensors_info

    def get_bfs_order(self):
        """
        Traverse the graph in order of breadth-first search.

        Returns:
            list, including the names of the leaf nodes arranged in BFS order.

        Raises:
            DebuggerParamValueError: If no default root exists or not every
                leaf node was traversed.
        """
        root = self.get_default_root()
        log.info('Randomly choose node %s as root to do BFS.', root.name)

        bfs_order = []
        self.get_bfs_graph(root.name, bfs_order)
        length = len(self._leaf_nodes)
        # Set mirror of `bfs_order` so the membership test below is O(1).
        visited = set(bfs_order)
        # Find rest un-traversed nodes
        for node_name in self._leaf_nodes:
            if node_name in visited:
                continue
            start = len(bfs_order)
            self.get_bfs_graph(node_name, bfs_order)
            visited.update(bfs_order[start:])

        if len(bfs_order) != length:
            log.error("The length of bfs and leaf nodes are not equal.")
            msg = "Not all nodes are traversed!"
            raise DebuggerParamValueError(msg)

        return bfs_order

    def get_bfs_graph(self, node_name, bfs_order):
        """
        Traverse the graph breadth-first from one node, extending `bfs_order` in place.

        Both the inputs and the outputs of a node are followed. Names that are
        not leaf nodes are skipped with a warning.

        Args:
            node_name (str): The name of the node to start from.
            bfs_order (list[str]): Names traversed so far; newly reached leaf
                node names are appended to it.
        """
        pending = deque([node_name])
        # Names that are already emitted or queued; replaces linear scans of
        # the queue and of `bfs_order` without changing the resulting order.
        seen = set(bfs_order)
        seen.add(node_name)
        while pending:
            node_name = pending.popleft()
            node = self._leaf_nodes.get(node_name)
            if not node:
                log.warning('Cannot find node %s in graph. Ignored.', node_name)
                continue

            bfs_order.append(node_name)
            # Inputs first, then outputs, as the traversal order depends on it.
            for neighbours in (node.input, node.output):
                if not neighbours:
                    continue
                for name in neighbours.keys():
                    if name not in seen:
                        seen.add(name)
                        pending.append(name)

    def get_default_root(self):
        """
        Get a node as default root for BFS in graph. The leaf node whose
        node id is '1' is used as the default root.

        Returns:
            Node, the leaf node used as the default root.

        Raises:
            DebuggerParamValueError: If no leaf node has the node id '1'.
        """
        for item in self._leaf_nodes.values():
            if item.node_id == '1':
                return item

        log.error("Abnormal graph. Invalid node for BFS.")
        msg = 'Abnormal graph. Invalid node for BFS.'
        raise DebuggerParamValueError(msg)
| @@ -0,0 +1,61 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| This file is used to define the node of graph and associated base types. | |||||
| """ | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
class NodeTree:
    """A tree of named nodes; each node keeps its children keyed by their short name."""

    def __init__(self, node_name='', node_type=None):
        self.node_name = node_name
        self._node_type = node_type
        # Maps the short (last path component) name to the child NodeTree.
        self._children = {}

    @property
    def node_type(self):
        """The type of this node."""
        return self._node_type

    @node_type.setter
    def node_type(self, value):
        """Replace the type of this node."""
        self._node_type = value

    def add(self, name, node_type=None):
        """Create a child called `name` (replacing any existing one) and return it."""
        if self.node_name:
            # Children carry the full '/'-joined path as their node name.
            child_name = '/'.join([self.node_name, name])
        else:
            child_name = name
        child = NodeTree(child_name, node_type)
        self._children[name] = child
        return child

    def get(self, sub_name):
        """Return the child stored under `sub_name`, or None if there is none."""
        return self._children.get(sub_name)

    def get_children(self):
        """Yield a (short name, child node) pair for every child."""
        yield from self._children.items()

    def remove(self, sub_name):
        """Remove the child stored under `sub_name`.

        Raises:
            DebuggerParamValueError: If there is no such child.
        """
        try:
            del self._children[sub_name]
        except KeyError as err:
            log.error("Failed to find node %s. %s", sub_name, err)
            raise DebuggerParamValueError("Failed to find node {}".format(sub_name))
| @@ -0,0 +1,233 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """The definition of tensor stream.""" | |||||
| from abc import abstractmethod, ABC | |||||
| import numpy as np | |||||
| from mindinsight.utils.tensor import TensorUtils | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.common.utils import NUMPY_TYPE_MAP | |||||
| from mindinsight.debugger.proto.ms_graph_pb2 import DataType | |||||
class BaseTensor(ABC):
    """
    Abstract base of the tensor records handled by the debugger.

    Subclasses supply `name`, `dtype`, `shape`, `value` and
    `get_tensor_value_by_shape`; this class turns them into the dict
    representations returned by `get_basic_info` and `get_full_info`.

    Args:
        step (int): The step the tensor belongs to. Default: 0.
    """

    def __init__(self, step=0):
        self._step = step

    @property
    @abstractmethod
    def name(self):
        """The property of tensor name."""

    @property
    @abstractmethod
    def dtype(self):
        """The property of tensor dtype."""

    @property
    @abstractmethod
    def shape(self):
        """The property of tensor shape."""

    @property
    @abstractmethod
    def value(self):
        """The property of tensor value."""

    @abstractmethod
    def get_tensor_value_by_shape(self, shape=None):
        """Get tensor value by shape."""

    def get_tensor_serializable_value_by_shape(self, shape=None):
        """
        Get the value info merged into the result of `get_full_info`.

        This default exists so that `get_full_info` works for every subclass,
        including those that only implement `get_tensor_value_by_shape`.
        Subclasses may override it to add more entries or to convert the value.

        Args:
            shape: Forwarded unchanged to `get_tensor_value_by_shape`.

        Returns:
            dict, with the single key `value`.
        """
        return {'value': self.get_tensor_value_by_shape(shape)}

    def _to_dict(self):
        """Get tensor info in dict format (without the value)."""
        res = {
            'full_name': self.name,
            'step': self._step,
            'dtype': self.dtype,
            'shape': self.shape,
            'has_prev_step': False
        }
        return res

    def get_basic_info(self):
        """
        Return basic info about the tensor.

        Returns:
            dict, the result of `_to_dict` plus `value`. The real value is only
            inlined when `shape` is empty; otherwise the placeholder text
            'click to view' is used.
        """
        if not self.shape:
            value = self.value
        else:
            value = 'click to view'
        res = self._to_dict()
        res['value'] = value
        return res

    def get_full_info(self, shape=None):
        """
        Get tensor info with value.

        Args:
            shape: Forwarded to `get_tensor_serializable_value_by_shape`.

        Returns:
            dict, the result of `_to_dict` updated with the value info.
        """
        res = self._to_dict()
        value_info = self.get_tensor_serializable_value_by_shape(shape)
        res.update(value_info)
        return res
class OpTensor(BaseTensor):
    """
    Tensor data structure for operator Node.

    Args:
        tensor_proto (TensorProto): The proto the tensor is built from. Its
            `node_name`, `slot`, `data_type`, `dims` and `tensor_content`
            fields are read.
        step (int): The step the tensor belongs to. Default: 0.
    """
    # Largest number of elements a sliced tensor may have to be returned by
    # `get_tensor_value_by_shape`; bigger slices are replaced by a text marker.
    max_number_data_show_on_ui = 100000

    def __init__(self, tensor_proto, step=0):
        # the type of tensor_proto is TensorProto
        super().__init__(step)
        self._tensor_proto = tensor_proto
        self._value = self.generate_value(tensor_proto)

    @property
    def name(self):
        """The property of tensor name, in the form `node_name:slot`."""
        node_name = self._tensor_proto.node_name
        slot = self._tensor_proto.slot
        return ':'.join([node_name, slot])

    @property
    def dtype(self):
        """The property of tensor dtype, as the name of the `DataType` enum value."""
        tensor_type = DataType.Name(self._tensor_proto.data_type)
        return tensor_type

    @property
    def shape(self):
        """The property of tensor shape, as a list copied from the proto `dims`."""
        return list(self._tensor_proto.dims)

    @property
    def value(self):
        """The property of tensor value as nested lists, or None if there is no value."""
        tensor_value = None
        if self._value is not None:
            tensor_value = self._value.tolist()
        return tensor_value

    @property
    def numpy_value(self):
        """The property of tensor value in numpy type."""
        return self._value

    def generate_value(self, tensor_proto):
        """
        Generate tensor value from proto.

        Args:
            tensor_proto (TensorProto): The proto holding `tensor_content`.

        Returns:
            Union[None, numpy.ndarray], None when `tensor_content` is empty,
            otherwise the content decoded with the numpy type looked up in
            `NUMPY_TYPE_MAP` and reshaped to `self.shape`.
        """
        tensor_value = None
        if tensor_proto.tensor_content:
            tensor_value = tensor_proto.tensor_content
            np_type = NUMPY_TYPE_MAP.get(self.dtype)
            tensor_value = np.frombuffer(tensor_value, dtype=np_type)
            tensor_value = tensor_value.reshape(self.shape)
        return tensor_value

    def get_tensor_serializable_value_by_shape(self, shape=None):
        """
        Get tensor value info by shape.

        Args:
            shape (tuple): The specified range of tensor value.

        Returns:
            dict, for an array result it holds `statistics` (built by the
            `TensorUtils` helpers) and `value` as nested lists. When the slice
            is too large, `value` is the text marker returned by
            `get_tensor_value_by_shape`. The dict is empty when the tensor has
            no value.
        """
        tensor_value = self.get_tensor_value_by_shape(shape)
        res = {}
        if isinstance(tensor_value, np.ndarray):
            statistics = TensorUtils.get_statistics_from_tensor(tensor_value)
            res['statistics'] = TensorUtils.get_statistics_dict(statistics)
            res['value'] = tensor_value.tolist()
        elif isinstance(tensor_value, str):
            # Keep the "too large" marker instead of silently dropping it, so
            # the caller can tell a value exists but is not shown.
            res['value'] = tensor_value
        return res

    def get_tensor_value_by_shape(self, shape=None):
        """
        Get tensor value by shape.

        Args:
            shape (tuple): The index applied to the stored numpy value. It must
                have one entry per tensor dimension. Anything that is not a
                tuple (including None) selects the whole tensor.

        Returns:
            Union[None, str, numpy.ndarray], None if the tensor has no value,
            the text "Too large to show." if the selected sub-tensor has more
            than `max_number_data_show_on_ui` elements, otherwise the
            (sub-)tensor.

        Raises:
            DebuggerParamValueError: If `shape` does not match the number of
                dimensions or indexes outside of the tensor.
        """
        if self._value is None:
            log.warning("%s has no value yet.", self.name)
            return None
        if not isinstance(shape, tuple):
            log.info("Get the whole tensor value with shape is %s", shape)
            return self._value
        if len(shape) != len(self.shape):
            log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
            raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
        try:
            value = self._value[shape]
        except IndexError as err:
            log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
            log.exception(err)
            raise DebuggerParamValueError("Invalid shape. Shape unmatched.") from err
        if isinstance(value, np.ndarray):
            if value.size > self.max_number_data_show_on_ui:
                # Log before replacing `value`, while the size is still known.
                log.info("The tensor size is %s, which is too large to show on UI.", value.size)
                value = "Too large to show."
        else:
            # Indexing every dimension yields a numpy scalar; wrap it so that
            # callers always receive an ndarray.
            value = np.asarray(value)
        return value
class ConstTensor(BaseTensor):
    """
    Tensor data structure for Const Node.

    Args:
        const_proto (NamedValueProto): The proto the constant is built from.
            Its `key` and `value` fields are read.
    """

    def __init__(self, const_proto):
        # the type of const_proto is NamedValueProto
        super().__init__()
        self._const_proto = const_proto

    def set_step(self, step):
        """Set step value."""
        self._step = step

    @property
    def name(self):
        """The property of tensor name, i.e. the proto key followed by ':0'."""
        return self._const_proto.key + ':0'

    @property
    def dtype(self):
        """The property of tensor dtype, as the name of the `DataType` enum value."""
        return DataType.Name(self._const_proto.value.dtype)

    @property
    def shape(self):
        """The property of tensor shape. Always an empty list for a const value."""
        return []

    @property
    def value(self):
        """
        The property of tensor value.

        Returns:
            The value of the first set field of the value proto that is not
            `dtype`, or None if there is no such field.
        """
        fields = self._const_proto.value.ListFields()
        if len(fields) != 2:
            log.warning("Unexpected const proto <%s>.\n Please check offline.", self._const_proto)
        # ListFields() yields (FieldDescriptor, value) pairs, so the field has
        # to be identified through the descriptor's `name`; comparing the
        # descriptor itself with a string never matches.
        for field_descriptor, field_value in fields:
            if field_descriptor.name != 'dtype':
                return field_value
        return None

    def get_tensor_value_by_shape(self, shape=None):
        """
        Get the const value.

        Args:
            shape: Ignored apart from a warning when it is not None.

        Returns:
            The same object as the `value` property.
        """
        if shape is not None:
            log.warning("Invalid shape for const value.")
        return self.value
| @@ -0,0 +1,300 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the watchpoint stream.""" | |||||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition | |||||
# Maps the upper-case condition name stored under the `condition` key of a
# watchpoint's watch condition to the matching `WatchCondition.Condition`
# proto enum value. It is looked up when the SetCMD proto is built.
WATCHPOINT_CONDITION_MAPPING = {
    'INF': WatchCondition.Condition.inf,
    'NAN': WatchCondition.Condition.nan,
    'OVERFLOW': WatchCondition.Condition.overflow,
    'MAX_GT': WatchCondition.Condition.max_gt,
    'MAX_LT': WatchCondition.Condition.max_lt,
    'MIN_GT': WatchCondition.Condition.min_gt,
    'MIN_LT': WatchCondition.Condition.min_lt,
    'MAX_MIN_GT': WatchCondition.Condition.max_min_gt,
    'MAX_MIN_LT': WatchCondition.Condition.max_min_lt,
    'MEAN_GT': WatchCondition.Condition.mean_gt,
    'MEAN_LT': WatchCondition.Condition.mean_lt
}
class WatchNodeTree:
    """
    Tree that records which graph nodes are watched.

    Args:
        node_name (str): The name of this tree node. Default: ''.
        node_type (str): The graph node type; it is stored translated to
            'scope' or 'leaf'. Default: None.
        full_name (str): The full name of the node. Default: ''.
        watch_status (int): One of the status constants below. Default: 1.
    """
    NOT_WATCH = 0  # the scope node and the nodes below are not watched
    PARTIAL_WATCH = 1  # at least one node under the scope node is not watched
    TOTAL_WATCH = 2  # the scope node and the nodes below are all watched

    def __init__(self, node_name='', node_type=None, full_name='', watch_status=1):
        self._node_name = node_name
        self._full_name = full_name
        self._node_type = self._translate_node_type(node_type)
        self._watch_status = watch_status
        # Maps the name segment of a child to the child's WatchNodeTree.
        self._children = {}

    @property
    def node_name(self):
        """Return the name of this tree node."""
        return self._node_name

    @property
    def full_name(self):
        """Return the full name given at construction."""
        return self._full_name

    @property
    def node_type(self):
        """Return the watch node type, either 'scope' or 'leaf'."""
        return self._node_type

    @node_type.setter
    def node_type(self, value):
        """Store `value` translated to the watch node type."""
        self._node_type = self._translate_node_type(value)

    @property
    def watch_status(self):
        """Return the watch status of this tree node."""
        return self._watch_status

    def enable_watch_status(self):
        """Mark this tree node as totally watched."""
        self._watch_status = WatchNodeTree.TOTAL_WATCH

    @staticmethod
    def _translate_node_type(node_type):
        """
        Translate a graph node type to the watch node type.

        Returns:
            str, 'scope' for an empty type, a name scope or an aggregation
            scope; 'leaf' for anything else.
        """
        if not node_type:
            return 'scope'
        scope_types = (NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.AGGREGATION_SCOPE.value)
        if node_type in scope_types:
            return 'scope'
        return 'leaf'

    def get(self, sub_name):
        """Return the child registered under `sub_name`, or None if there is none."""
        return self._children.get(sub_name)

    def get_children(self):
        """Yield a `(name segment, child tree)` pair for every direct child."""
        for child_item in self._children.items():
            yield child_item

    def add_node(self, node_name, node_type, full_name=''):
        """
        Add a watched node, creating the scopes on the way to it.

        Args:
            node_name (str): The node name relative to this tree node, with
                scopes separated by '/'.
            node_type (str): The node type.
            full_name (str): The full name of node.
        """
        log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
        head, slash, rest = node_name.partition('/')
        if not slash:
            # Last segment reached: watch the node itself.
            existing = self.get(head)
            if existing:
                existing.enable_watch_status()
            else:
                self.add(head, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
            return
        scope_tree = self.get(head)
        if not scope_tree:
            scope_tree = self.add(head, watch_status=1)
        scope_tree.add_node(rest, node_type, full_name)

    def add(self, name, node_type=None, full_name='', watch_status=1):
        """
        Create a child tree and register it under `name`.

        Returns:
            WatchNodeTree, the new child. Its node name is this node's name
            joined with `name` by '/', or just `name` when this node's name
            is empty.
        """
        if self._node_name:
            child_name = '/'.join((self._node_name, name))
        else:
            child_name = name
        child = WatchNodeTree(child_name, node_type, full_name, watch_status)
        self._children[name] = child
        return child

    def remove_node(self, node_name):
        """
        Remove a node from below this tree node.

        Args:
            node_name (str): The node name relative to this tree node, with
                scopes separated by '/'.

        Raises:
            DebuggerParamValueError: If the first segment of `node_name` is
                not a child of this tree node.
        """
        log.debug("Remove %s", node_name)
        head, slash, rest = node_name.partition('/')
        child = self._children.get(head)
        if not child:
            log.error("Failed to find node %s in WatchNodeTree.", head)
            raise DebuggerParamValueError("Failed to find node {}".format(head))
        is_direct_child = not slash
        if not is_direct_child:
            child.remove_node(rest)
        # Drop the child when it is the target itself or has become empty.
        if is_direct_child or child.watch_status == WatchNodeTree.NOT_WATCH:
            self._children.pop(head)
        if self._children:
            self._watch_status = WatchNodeTree.PARTIAL_WATCH
        else:
            self._watch_status = WatchNodeTree.NOT_WATCH
class Watchpoint:
    """
    The class of watchpoint stream.

    Args:
        watchpoint_id (int): The id of Watchpoint.
        watch_condition (dict): The condition of Watchpoint.

            - condition (str): The condition name; it is looked up in
              `WATCHPOINT_CONDITION_MAPPING` when the set command is built.
            - param: Optional value copied into the set command when truthy.
    """

    def __init__(self, watchpoint_id, watch_condition):
        self._id = watchpoint_id
        self._condition = watch_condition
        self._watch_node = WatchNodeTree()

    @property
    def watchpoint_id(self):
        """Return the id of this watchpoint."""
        return self._id

    @property
    def nodes(self):
        """Return the root of the watch node tree."""
        return self._watch_node

    @property
    def condition(self):
        """Return the watch condition dict."""
        return self._condition

    def copy_nodes_from(self, other_watchpoint):
        """
        Take over the watch node tree of another watchpoint.

        The tree object itself is shared with `other_watchpoint`, it is not
        duplicated.

        Args:
            other_watchpoint (Watchpoint): Other watchpoint.
        """
        self._watch_node = other_watchpoint.nodes

    def add_nodes(self, nodes):
        """
        Add nodes into the watchpoint.

        Args:
            nodes: A single node or a list of nodes. Each node provides the
                attributes `name`, `type` and `full_name`.
        """
        if not nodes:
            log.warning("Add empty nodes.")
            return
        node_list = nodes if isinstance(nodes, list) else [nodes]
        for item in node_list:
            self._watch_node.add_node(item.name, item.type, item.full_name)

    def remove_nodes(self, nodes):
        """
        Remove nodes from the watchpoint.

        Args:
            nodes (Union[str, list[str]]): Node or tensor names; everything
                from the first ':' on is ignored.
        """
        if not nodes:
            return
        name_list = nodes if isinstance(nodes, list) else [nodes]
        for raw_name in name_list:
            self._watch_node.remove_node(raw_name.split(':')[0])

    def get_node_status(self, node_name, node_type, full_name):
        """
        Judge if the node is in watch nodes.

        When the node is covered by a totally watched ancestor scope, the node
        itself is added to the watch node tree as a side effect.

        Returns:
            int, one of the `WatchNodeTree` status constants.
        """
        status = WatchNodeTree.PARTIAL_WATCH
        current = self._watch_node
        for segment in node_name.split('/'):
            current = current.get(segment)
            if current is None:
                status = WatchNodeTree.NOT_WATCH
                break
            if current.watch_status == WatchNodeTree.TOTAL_WATCH:
                status = WatchNodeTree.TOTAL_WATCH
                break
        covered_by_ancestor = (status == WatchNodeTree.TOTAL_WATCH
                               and current.node_name != node_name)
        if covered_by_ancestor:
            self._watch_node.add_node(node_name, node_type, full_name)
        return status

    def get_watch_node(self, cur_watch_node, watch_node_list):
        """
        Traverse the watch nodes and add total watched node list to `watch_node_list`.

        The traversal does not descend below a totally watched node.

        Args:
            cur_watch_node (WatchNodeTree): The current watch node.
            watch_node_list (list[WatchNodeTree]): The list of total watched node.
        """
        if cur_watch_node.watch_status != WatchNodeTree.TOTAL_WATCH:
            for _, child in cur_watch_node.get_children():
                self.get_watch_node(child, watch_node_list)
            return
        watch_node_list.append(cur_watch_node)

    def get_set_cmd(self):
        """Return the watchpoint in proto format (`SetCMD`)."""
        # collect the totally watched nodes
        watched = []
        self.get_watch_node(self._watch_node, watched)
        # fill the SetCMD proto
        cmd = SetCMD()
        cmd.id = self._id
        cmd.delete = False
        condition_name = self._condition.get('condition')
        cmd.watch_condition.condition = WATCHPOINT_CONDITION_MAPPING.get(condition_name)
        param = self._condition.get('param')
        if param:
            # at most one param is provided
            cmd.watch_condition.value = param
        for tree_node in watched:
            proto_node = cmd.watch_nodes.add()
            proto_node.node_name = tree_node.full_name
            proto_node.node_type = tree_node.node_type
        return cmd

    def get_watch_condition_info(self):
        """Return a dict with the watchpoint `id` and its `watch_condition`."""
        return {
            'id': self._id,
            'watch_condition': self._condition
        }
class WatchpointHit:
    """
    The watchpoint hit structure.

    Args:
        tensor_proto: The proto of the tensor that hit the watchpoint. Its
            `node_name` and `slot` fields are read.
        watchpoint (Watchpoint): The watchpoint that was hit.
        node_name (str): The node name used to build `tensor_name`.
    """

    def __init__(self, tensor_proto, watchpoint, node_name):
        self._node_name = node_name
        self._full_name = tensor_proto.node_name
        self._slot = tensor_proto.slot
        self._watchpoint = watchpoint

    @property
    def tensor_full_name(self):
        """The property of the full tensor name, `proto node_name:slot`."""
        tensor_name = ':'.join([self._full_name, self._slot])
        return tensor_name

    @property
    def tensor_name(self):
        """The property of tensor_name, `node_name:slot`."""
        return ':'.join([self._node_name, self._slot])

    @property
    def watchpoint(self):
        """The property of watchpoint, as the watch condition info dict."""
        watchpoint = self._watchpoint.get_watch_condition_info()
        return watchpoint

    def __eq__(self, other):
        """
        Define the equal condition.

        Two hits are equal when they refer to the same full tensor name and
        the same watchpoint info. Comparing with an object of another type
        yields `NotImplemented` so that Python falls back to its default
        handling instead of failing on a missing attribute.
        """
        if not isinstance(other, WatchpointHit):
            return NotImplemented
        return (self.tensor_full_name == other.tensor_full_name
                and self.watchpoint == other.watchpoint)
| @@ -0,0 +1,23 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Import the streams handlers.""" | |||||
| from .event_handler import EventHandler | |||||
| from .metadata_handler import MetadataHandler | |||||
| from .graph_handler import GraphHandler | |||||
| from .tensor_handler import TensorHandler | |||||
| from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler | |||||
# Public names of this package: the handler classes imported above.
__all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler',
           'WatchpointHandler', 'WatchpointHitHandler']
| @@ -0,0 +1,34 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the stream handler base.""" | |||||
| from abc import abstractmethod | |||||
class StreamHandlerBase:
    """
    The stream handler base.

    Subclasses are expected to override `put` and `get`. The class does not
    use the ABC metaclass, so `abstractmethod` only marks the methods and
    does not prevent instantiation; calling a method that was not overridden
    raises `NotImplementedError`.
    """

    @abstractmethod
    def put(self, value):
        """
        Abstract method of set data.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def get(self, filter_condition):
        """
        Abstract method of get data.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def clean(self):
        """Clean cache by running `__init__` again without arguments."""
        self.__init__()
| @@ -0,0 +1,159 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the message handler.""" | |||||
| import uuid | |||||
| from queue import Queue, Empty | |||||
| from threading import Lock | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||||
class EventHandler(StreamHandlerBase):
    """
    Message Handler.

    Events are kept in a fixed-size cache that is written round-robin. A
    position has the form `flag:index` (or just `index`), where the flag
    identifies the pass over the cache in which the index was written.
    """
    max_limit = 1000  # the max number of items in cache

    def __init__(self):
        self._prev_flag = str(uuid.uuid4())  # flag of the previous pass over the cache
        self._cur_flag = str(uuid.uuid4())  # flag of the current pass over the cache
        self._next_idx = 0  # index the next event is written to
        self._event_cache = [None] * self.max_limit
        # Maps a request id to the one-slot queue its caller is blocked on.
        self._pending_requests = {}
        self._lock = Lock()

    @property
    def next_pos(self):
        """The next pos to be updated in cache, in the form `flag:index`."""
        return ':'.join([self._cur_flag, str(self._next_idx)])

    def has_pos(self, pos):
        """
        Get the event according to pos.

        Args:
            pos (str): The position, `flag:index` or `index`.

        Returns:
            Union[dict, None], the cached event at the index if the position
            has no flag, belongs to the already written part of the current
            pass, or belongs to the not yet overwritten part of the previous
            pass; otherwise None.

        Raises:
            DebuggerParamValueError: If the index in `pos` is invalid.
        """
        cur_flag, cur_idx = self._parse_pos(pos)
        event = self._event_cache[cur_idx]
        if event is not None:
            if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \
                    (cur_flag == self._prev_flag and cur_idx >= self._next_idx):
                return event
        return None

    def clean(self):
        """Clean event cache and release all pending requests with pos '0'."""
        with self._lock:
            self._prev_flag = str(uuid.uuid4())
            self._cur_flag = str(uuid.uuid4())
            self._next_idx = 0
            self._event_cache = [None] * self.max_limit
            value = {'metadata': {'pos': '0'}}
            self.clean_pending_requests(value)
        log.debug("Clean event cache.")

    def put(self, value):
        """
        Put value into event_cache.

        The position of the following event is written into
        `value['metadata']['pos']`, then the event is stored and handed to
        every pending request.

        Args:
            value (dict): The event to be put into cache.

        Raises:
            DebuggerParamValueError: If `value` is not a dict.
        """
        if not isinstance(value, dict):
            log.error("Dict type required when put event message.")
            raise DebuggerParamValueError("Dict type required when put event message.")

        with self._lock:
            log.debug("Put the %d-th message into queue. \n %d requests is waiting.",
                      self._next_idx, len(self._pending_requests))
            cur_pos = self._next_idx
            # update next pos
            self._next_idx += 1
            if self._next_idx >= self.max_limit:
                # wrap around and start a new pass with a fresh flag
                self._next_idx = 0
                self._prev_flag = self._cur_flag
                self._cur_flag = str(uuid.uuid4())
            # set next pos
            if not value.get('metadata'):
                value['metadata'] = {}
            value['metadata']['pos'] = self.next_pos
            self._event_cache[cur_pos] = value
            # feed the value for pending requests
            self.clean_pending_requests(value)

    def clean_pending_requests(self, value):
        """
        Hand `value` to every pending request and forget the requests.

        The method does not acquire the lock itself; within this class it is
        only called while the lock is held.
        """
        for request in self._pending_requests.values():
            request.put(value)
        self._pending_requests = {}

    def get(self, filter_condition=None):
        """
        Get the pos-th value from event_cache according to filter_condition.

        If the event is not cached yet, the call blocks until the next event
        arrives or the wait times out.

        Args:
            filter_condition (str): The position of the event in cache,
                `flag:index` or `index`. None is treated as '0'. Default: None.

        Returns:
            object, the pos-th event. After a timeout a dict that only
            carries the requested pos in its metadata is returned.

        Raises:
            DebuggerParamValueError: If the index in the position is invalid.
        """
        if filter_condition is None:
            # The documented default must not crash in `_parse_pos`; start
            # from the beginning of the cache instead.
            filter_condition = '0'
        flag, idx = self._parse_pos(filter_condition)
        cur_id = str(uuid.uuid4())
        with self._lock:
            # reset the pos after the cache is re-initialized.
            if not flag or flag not in [self._cur_flag, self._prev_flag]:
                idx = 0
            # get event from cache immediately
            if idx != self._next_idx and self._event_cache[idx]:
                return self._event_cache[idx]
            # wait for the event
            cur_queue = Queue(maxsize=1)
            self._pending_requests[cur_id] = cur_queue
        # block until event has been received
        event = self._wait_for_event(cur_id, cur_queue, filter_condition)

        return event

    def _parse_pos(self, pos):
        """
        Split a position into its flag and index.

        Args:
            pos (str): The position, `flag:index` or `index`.

        Returns:
            tuple[str, int], the flag ('' unless `pos` has exactly two
            ':'-separated parts) and the index.

        Raises:
            DebuggerParamValueError: If the index is not an integer or is
                outside of [0, max_limit).
        """
        elements = pos.split(':')
        try:
            idx = int(elements[-1])
        except ValueError as err:
            log.error("Invalid index. The index in pos should be digit but get pos:%s", pos)
            raise DebuggerParamValueError("Invalid pos.") from err

        if idx < 0 or idx >= self.max_limit:
            log.error("Invalid index. The index in pos should between [0, %d)", self.max_limit)
            raise DebuggerParamValueError(f"Invalid pos. {idx}")
        flag = elements[0] if len(elements) == 2 else ''

        return flag, idx

    def _wait_for_event(self, cur_id, cur_queue, pos):
        """
        Wait for the pos-th event.

        Args:
            cur_id (str): The id the request is registered under.
            cur_queue (Queue): The queue the event is delivered to.
            pos (str): The requested position, echoed back on timeout.

        Returns:
            dict, the received event, or `{'metadata': {'pos': pos}}` if no
            event arrived in time.
        """
        try:
            # set the timeout to 25 seconds, which is less than the timeout limit of the UI
            event = cur_queue.get(timeout=25)
        except Empty:
            event = None

        if event is None:
            with self._lock:
                # the request may already have been removed by put() or clean()
                if self._pending_requests.get(cur_id):
                    self._pending_requests.pop(cur_id)
                log.debug("Clean timeout request. Left pending requests: %d",
                          len(self._pending_requests))
            event = {'metadata': {'pos': pos}}

        return event
| @@ -0,0 +1,314 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the graph stream handler.""" | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||||
| DebuggerNodeNotInGraphError, DebuggerGraphNotExistError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.stream_cache.debugger_graph import DebuggerGraph | |||||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||||
class GraphHandler(StreamHandlerBase):
    """
    Graph Handler.

    Caches the graph proto received from the grpc server together with the
    `DebuggerGraph` built from it, and answers node/graph queries against
    that cached graph.
    """

    def __init__(self):
        # Raw proto message, exposed through the `graph` property.
        self._graph_proto = None
        # DebuggerGraph built from the proto; None until `put` is called.
        self._graph = None
        # Result of the latest `search_nodes` call.
        self._searched_node_list = []
        # Node names in BFS order, refreshed on every `put`.
        self.bfs_order = []

    @property
    def graph(self):
        """The cached graph proto message (not the built DebuggerGraph)."""
        return self._graph_proto

    def put(self, value):
        """
        Put value into graph cache. Called by grpc server.

        Args:
            value (GraphProto): The Graph proto message.
        """
        self._graph_proto = value
        log.info("Put graph into cache.")
        # build graph
        graph = DebuggerGraph()
        graph.build_graph(value)
        self._graph = graph
        self.bfs_order = self._graph.get_bfs_order()

    def get(self, filter_condition=None):
        """
        Get the graph of specific node.

        Args:
            filter_condition (dict):

                - name (str): The full debug node name.

                - single_node (bool): If True, return the graph from root
                    to the specific node; else, return the sublayer of the
                    graph. Default: False.

        Returns:
            dict, format is {'graph': {...}}. The inner dict is empty when no
            graph has been received yet.
        """
        try:
            self._graph_exists()
        except DebuggerGraphNotExistError:
            log.warning('The graph is empty. To view a graph, '
                        'please start the training script first.')
            return {'graph': {}}

        if filter_condition is None:
            filter_condition = {}

        single_node = filter_condition.get('single_node', False)
        name = filter_condition.get('name')

        graph = {}
        # Only the literal True selects the single-node view.
        if single_node is True:
            nodes = self.get_single_node(name)
        else:
            nodes = self.list_nodes(name)
        graph.update(nodes)

        return {'graph': graph}

    def get_tensor_history(self, node_name, depth=0):
        """
        Get the tensor history of a specified node.

        Args:
            node_name (str): The debug name of the node.
            depth (int): The number of layers the user
                wants to trace. Default is 0.

        Returns:
            dict, basic tensor history, only including tensor name and tensor type and node type.

        Raises:
            DebuggerGraphNotExistError: If the graph has not been received.
            DebuggerNodeNotInGraphError: If the node is not in the graph.
        """
        self._graph_exists()
        if not self._graph.exist_node(node_name):
            raise DebuggerNodeNotInGraphError(node_name)

        tensor_history, cur_outputs_nums = self._graph.get_tensor_history(
            node_name, depth
        )
        # The first `cur_outputs_nums` entries are tagged as outputs,
        # the remaining entries as inputs.
        self._update_tensor_history(tensor_history[0:cur_outputs_nums], 'output')
        self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input')
        log.debug("Get %d tensors in tensor history for node <%s>.", len(tensor_history), node_name)

        return {'tensor_history': tensor_history}

    @staticmethod
    def _update_tensor_history(tensor_history, tensor_type):
        """
        Add tensor source type for tensor history.

        Args:
            tensor_history (list[dict]): Tensor history from Graph stream. Each element has two
                keys: `node_type` and `name`. `node_type` refers to the type of the node which
                the tensor come from. `name` refers to the tensor name.
            tensor_type (str): The source type of the tensor. `input` or `output`.
        """
        for single_tensor_info in tensor_history:
            single_tensor_info['type'] = tensor_type

    def search_nodes(self, pattern):
        """
        Search nodes by given pattern.

        The matched node list is also remembered and can be read back with
        `get_searched_node_list`.

        Args:
            pattern (Union[str, None]): The pattern of the node to search,
                if None, return all node names.

        Returns:
            dict, the searched node.
        """
        self._graph_exists()
        self._searched_node_list = self._graph.search_nodes_by_pattern(pattern)
        nodes = self._graph.get_nodes(self._searched_node_list)

        return {'nodes': nodes}

    def get_node_names(self, pattern=None):
        """
        Get graph nodes according to pattern.

        Raises:
            DebuggerGraphNotExistError: If the graph has not been received.
        """
        # Fail with the same error as the sibling query methods instead of
        # an AttributeError on a missing graph.
        self._graph_exists()
        return self._graph.search_nodes_by_pattern(pattern)

    def get_searched_node_list(self):
        """Get searched node list."""
        return self._searched_node_list

    def get_node_type(self, node_name):
        """
        Get the type of the specified node.

        Args:
            node_name (str): The debug name of the node.

        Returns:
            A string of the node type, name_scope or leaf.
        """
        self._graph_exists()
        node_type = self._graph.get_node_type(node_name)

        return node_type

    def get_full_name(self, node_name):
        """
        Get full name according to ui node name.

        Args:
            node_name (str): The UI node name. A falsy name yields ''.

        Returns:
            str, the full name, or '' when `node_name` is falsy.

        Raises:
            DebuggerGraphNotExistError: If a name is given but the graph
                has not been received.
        """
        if not node_name:
            return ''
        self._graph_exists()
        return self._graph.get_full_name_by_node_name(node_name)

    def get_node_name_by_full_name(self, full_name):
        """
        Get UI node name by full name.

        Returns '' (and logs) when no graph has been received yet.
        """
        if self._graph:
            node_name = self._graph.get_node_name_by_full_name(full_name)
        else:
            node_name = ''
            log.info("No graph received yet.")

        return node_name

    def list_nodes(self, scope):
        """
        Get the nodes of every layer in graph.

        Args:
            scope (str): The name of a scope.

        Returns:
            TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': [<Node object>]}.
            example:
                {
                  "nodes" : [
                    {
                      "attr" :
                      {
                        "index" : "i: 0\\n"
                      },
                      "input" : {},
                      "name" : "input_tensor",
                      "output" :
                      {
                        "Default/TensorAdd-op17" :
                        {
                          "edge_type" : "data",
                          "scope" : "name_scope",
                          "shape" : [1, 16, 128, 128]
                        }
                      },
                      "output_i" : -1,
                      "proxy_input" : {},
                      "proxy_output" : {},
                      "independent_layout" : False,
                      "subnode_count" : 0,
                      "type" : "Data"
                    }
                  ]
                }

        Raises:
            DebuggerGraphNotExistError: If the graph has not been received.
            DebuggerNodeNotInGraphError: If `scope` is given but not in the graph.
        """
        self._graph_exists()
        if scope and not self._graph.exist_node(scope):
            raise DebuggerNodeNotInGraphError(node_name=scope)

        nodes = self._graph.list_node_by_scope(scope=scope)
        return {'nodes': nodes}

    def get_node_by_bfs_order(self, node_name=None, ascend=True):
        """
        Get the neighbour of a node in the breadth-first order of the graph.

        Args:
            node_name (str): The node name which will be regarded
                as the start node in graph. If None, the first entry of the
                BFS order is returned, or the last entry when `ascend` is False.
            ascend (bool): If True, return the entry after `node_name` in the
                BFS order; if False, the entry before it. Default is True.

        Returns:
            The adjacent entry of `self.bfs_order`, or None when `node_name`
            is already at the corresponding end of the order.

        Raises:
            DebuggerGraphNotExistError: If the graph has not been received.
            DebuggerParamValueError: If the BFS order is empty or `node_name`
                is not part of it.
        """
        self._graph_exists()
        bfs_order = self.bfs_order
        length = len(bfs_order)

        if not bfs_order:
            log.error('Cannot get the BFS order of the graph!')
            msg = 'Cannot get the BFS order of the graph!'
            raise DebuggerParamValueError(msg)

        if node_name is None:
            if ascend is False:
                next_node = bfs_order[-1]
            else:
                next_node = bfs_order[0]
        else:
            try:
                index = bfs_order.index(node_name)
                log.debug("The index of the node in BFS list is: %d", index)
            except ValueError as err:
                log.error('Cannot find the node: %s. Please check '
                          'the node name: %s', node_name, err)
                msg = f'Cannot find the node: {node_name}. ' \
                      f'Please check the node name {err}.'
                # Keep the original lookup failure attached as the cause.
                raise DebuggerParamValueError(msg) from err

            next_node = self.get_next_node_in_bfs(index, length, ascend)

        return next_node

    def get_next_node_in_bfs(self, index, length, ascend):
        """
        Get the next node in bfs order.

        Args:
            index (int): The current index.
            length (int): The number of entries in the BFS order.
            ascend (bool): Whether get the node in ascend order or not.

        Returns:
            The entry of `self.bfs_order` next to `index` in the requested
            direction, or None if `index` is out of range or already at the
            boundary.
        """
        next_node = None
        if 0 <= index < length:
            if ascend is True and index < length - 1:
                next_node = self.bfs_order[index + 1]
            elif ascend is False and index > 0:
                next_node = self.bfs_order[index - 1]

        return next_node

    def get_single_node(self, name):
        """
        Search node, and return every layer nodes until this node.

        Args:
            name (str): The name of node.

        Returns:
            dict, every layer nodes until this node.

        Raises:
            DebuggerGraphNotExistError: If the graph has not been received.
        """
        self._graph_exists()
        nodes = self._graph.search_single_node(name)

        return nodes

    def _graph_exists(self):
        """
        Check if the graph has been loaded in the debugger cache.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
        """
        if self._graph is None:
            log.error('The graph does not exist. Please start the '
                      'training script and try again.')
            raise DebuggerGraphNotExistError
| @@ -0,0 +1,131 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the metadata stream handler.""" | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.common.utils import ServerStatus | |||||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||||
class MetadataHandler(StreamHandlerBase):
    """Metadata Handler: caches the training metadata reported to the debugger."""

    def __init__(self):
        self._state = ServerStatus.PENDING
        self._device_name = ""
        self._step = 0
        self._client_ip = ""
        self._cur_node_name = ""
        self._cur_full_name = ""
        self._backend = ""

    @property
    def device_name(self):
        """The cached device name."""
        return self._device_name

    @property
    def step(self):
        """The cached current step."""
        return self._step

    @property
    def node_name(self):
        """The cached current node name."""
        return self._cur_node_name

    @node_name.setter
    def node_name(self, node_name):
        """Set the current node name."""
        self._cur_node_name = node_name

    @property
    def full_name(self):
        """The cached full name of the current node."""
        return self._cur_full_name

    @property
    def backend(self):
        """The cached backend name."""
        return self._backend

    @property
    def state(self):
        """The value of the cached `ServerStatus` member."""
        return self._state.value

    @state.setter
    def state(self, value):
        """
        Set the state.

        Args:
            value (str): The new state; converted with `ServerStatus(value)`.
        """
        self._state = ServerStatus(value)

    @property
    def client_ip(self):
        """The cached client ip."""
        return self._client_ip

    @client_ip.setter
    def client_ip(self, value):
        """
        Set the client ip.

        Args:
            value (str): The new ip; stored as `str(value)`.
        """
        self._client_ip = str(value)

    def put(self, value):
        """
        Put value into metadata cache. Called by grpc server.

        Args:
            value (MetadataProto): The Metadata proto message.
        """
        # Only the part of the device name before the first ':' is kept.
        device_name, *_ = value.device_name.split(':')
        self._device_name = device_name
        self._step = value.cur_step
        self._cur_full_name = value.cur_node
        # An empty backend field falls back to "Ascend".
        self._backend = value.backend or "Ascend"
        log.debug("Put metadata into cache at the %d-th step.", self._step)

    def get(self, filter_condition=None):
        """
        Get updated value. Called by main server.

        Args:
            filter_condition (str): Name of a single attribute to return.
                If None, all metadata fields are returned.

        Returns:
            dict, format is {'metadata': {...}}. For an attribute name this
            handler does not have, the value is ''.
        """
        if filter_condition is not None:
            single_value = getattr(self, filter_condition, '')
            return {'metadata': {filter_condition: single_value}}

        full_metadata = {
            'state': self.state,
            'step': self.step,
            'device_name': self.device_name,
            'pos': '0',
            'ip': self.client_ip,
            'node_name': self.node_name,
            'backend': self.backend
        }
        return {'metadata': full_metadata}
| @@ -0,0 +1,298 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the tensor stream handler.""" | |||||
| import numpy as np | |||||
| from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.proto.ms_graph_pb2 import DataType | |||||
| from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor | |||||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||||
| from mindinsight.utils.tensor import TensorUtils | |||||
class TensorHandler(StreamHandlerBase):
    """
    Tensor Handler.

    Caches tensors per tensor name and step, plus const values, and serves
    tensor value / history / diff queries from that cache.
    """

    def __init__(self):
        # Const tensors keyed by const tensor name.
        self._const_vals = {}
        # {tensor_name: {step: tensor}}
        self._tensors = {}
        # Updated by `clean_tensors`.
        self._cur_step = 0

    def put(self, value):
        """
        Put value into tensor cache. Called by grpc server.

        Args:
            value (dict): The Tensor proto message.

                - step (int): The current step of tensor.

                - tensor_protos (list[TensorProto]): The tensor proto.
        """
        tensor_protos = value.get('tensor_protos')
        merged_tensor = self._get_merged_tensor(tensor_protos)
        step = value.get('step', 0)
        # A non-empty `iter` field marks a previous-step tensor, which is
        # stored under the preceding step.
        if merged_tensor.iter and step > 0:
            log.debug("Received previous tensor.")
            step -= 1
        tensor = OpTensor(merged_tensor, step)
        self._put_tensor_into_cache(tensor, step)
        log.debug("Put tensor %s of step: %d, into cache", tensor.name, step)

    @staticmethod
    def _get_merged_tensor(tensor_protos):
        """
        Merged list of parsed tensor value into one.

        The content of all protos is concatenated in order and written onto
        the last proto, which is returned. Concatenation stops at the first
        proto without content.

        Args:
            tensor_protos (list[TensorProto]): List of tensor proto.

        Returns:
            TensorProto, merged tensor proto.
        """
        merged_tensor = tensor_protos[-1]
        if len(tensor_protos) > 1:
            # Collect the chunks and join once instead of repeated `bytes +=`.
            chunks = []
            for tensor_proto in tensor_protos:
                if not tensor_proto.tensor_content:
                    log.warning("Doesn't find tensor value for %s:%s",
                                tensor_proto.node_name, tensor_proto.slot)
                    break
                chunks.append(tensor_proto.tensor_content)
            merged_tensor.tensor_content = b''.join(chunks)
            log.debug("Merge multi tensor values into one.")
        return merged_tensor

    def _put_tensor_into_cache(self, tensor, step):
        """
        Put tensor into cache.

        Args:
            tensor (OpTensor): The tensor value.
            step (int): The step under which the tensor is stored.
        """
        self._tensors.setdefault(tensor.name, {})[step] = tensor

    def put_const_vals(self, const_vals):
        """
        Put const value into tensor cache.

        Entries with an empty key or value are skipped.

        Args:
            const_vals (list[NamedValueProto]): List of const values.
        """
        for const_val in const_vals:
            if not (const_val.value and const_val.key):
                continue
            if DataType.Name(const_val.value.dtype) == "DT_TENSOR":
                tensor_proto = const_val.value.tensor_val
                tensor_proto.node_name = const_val.key
                tensor_proto.slot = '0'
                const_tensor = OpTensor(tensor_proto)
            else:
                const_tensor = ConstTensor(const_val)
            self._const_vals[const_tensor.name] = const_tensor

    def get(self, filter_condition=None):
        """
        Get full tensor value.

        Args:
            filter_condition (dict): Filter condition.

                - name (str): The name of tensor.

                - node_type (str): The type of the node.

                - shape: Passed through to the tensor's `get_full_info`.

        Returns:
            dict, the tensor_value.

        Raises:
            DebuggerParamValueError: If no tensor is found for the name.
        """
        # The signature defaults to None; treat that as an empty condition
        # so the call fails with the regular "No tensor named" error.
        if filter_condition is None:
            filter_condition = {}
        name = filter_condition.get('name')
        node_type = filter_condition.get('node_type')
        shape = filter_condition.get('shape')
        tensor = self._get_tensor(name, node_type)
        if not tensor:
            log.error("No tensor named %s", name)
            raise DebuggerParamValueError("No tensor named {}".format(name))
        tensor_info = tensor.get_full_info(shape)
        self._update_has_prev_step_field(tensor_info, name, node_type)
        return {'tensor_value': tensor_info}

    def _get_tensor(self, tensor_name, node_type=None, step=None):
        """
        Get tensor according to tensor name and node_type.

        Args:
            tensor_name (str): Tensor name, format like `node_name:slot`.
            node_type (str): Node type.
            step (int): The step of tensor info. Default: None, which means
                the current step.

        Returns:
            Union[OPTensor, ConstTensor], the tensor object, or None if not found.
        """
        if step is None:
            step = self._cur_step
        tensor = self._tensors.get(tensor_name, {}).get(step)
        if not tensor and node_type == NodeTypeEnum.CONST.value:
            # Const values are keyed by the last path component of the name.
            const_name = tensor_name.rsplit('/', 1)[-1]
            tensor = self._const_vals.get(const_name)
            # Cache only a real hit, and keep entries of other steps intact.
            if tensor is not None:
                self._tensors.setdefault(tensor_name, {})[step] = tensor

        return tensor

    def _get_basic_info(self, tensor_name, node_type=None):
        """Get the basic tensor info of the current step, or None if not cached."""
        tensor = self._get_tensor(tensor_name, node_type)
        if tensor:
            return tensor.get_basic_info()

        return None

    def get_tensor_history(self, tensor_names):
        """
        Get tensor history for tensor names.

        Returns:
            dict, format is {'tensor_history': [...]}; an entry is None for a
            tensor that is not cached.
        """
        # only used by grpc server, could be remove later
        tensor_infos = [self._get_basic_info(tensor_name) for tensor_name in tensor_names]

        return {'tensor_history': tensor_infos}

    def update_tensor_history(self, tensor_history):
        """
        Add tensor basic info in tensor_history.

        Args:
            tensor_history (dict): Tensor history, including a list of tensor name and type.

        Returns:
            list[dict], the tensors whose current or previous value is not in
            the cache yet. Entries for a missing previous value are copies
            carrying `iter: 'prev'`.
        """
        missed_tensors = []
        for tensor_info in tensor_history.get('tensor_history'):
            tensor_name = tensor_info.get('full_name')
            node_type = tensor_info.get('node_type')
            basic_info = self._get_basic_info(tensor_name, node_type)
            flag = self._update_has_prev_step_field(basic_info, tensor_name, node_type)
            # `flag` is None when the previous value is not applicable;
            # only an explicit False means it is missing.
            if flag is False:
                missed_tensor = tensor_info.copy()
                missed_tensor['iter'] = 'prev'
                missed_tensors.append(missed_tensor)
                log.debug("Add previous view cmd for %s", tensor_name)
            # add `has_prev_step` field to tensor basic info.
            if basic_info:
                tensor_info.update(basic_info)
            else:
                missed_tensors.append(tensor_info)
                log.debug("Add view cmd for %s", tensor_name)

        return missed_tensors

    def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
        """
        Update has_prev_step field in tensor info.

        Returns:
            Union[bool, None], None for non-parameter nodes (or when there is
            no previous step), otherwise whether a previous value is cached.
        """
        flag = None
        if node_type == NodeTypeEnum.PARAMETER.value:
            flag = self._has_prev_tensor_value(tensor_name)
            if flag and tensor_info:
                tensor_info['has_prev_step'] = True
        return flag

    def _has_prev_tensor_value(self, tensor_name):
        """
        Check if the tensor has valid value of previous step.

        Args:
            tensor_name (str): Tensor name.

        Returns:
            Union[bool, None], whether the tensor has valid tensor value;
            None when there is no previous step.
        """
        flag = None
        # check if the tensor has previous step value.
        prev_step = self._cur_step - 1
        if prev_step < 0:
            return flag
        tensor = self._get_tensor(tensor_name, step=prev_step)
        flag = bool(tensor and tensor.value)
        return flag

    def get_tensor_value_by_name(self, tensor_name, prev=False):
        """
        Get the cached tensor object by name.

        Args:
            tensor_name (str): Tensor name.
            prev (bool): If True, look up the previous step. Default: False.

        Returns:
            The cached tensor object, or None if it is not cached or there is
            no previous step.
        """
        cur_step = self._cur_step
        step = cur_step - 1 if prev else cur_step
        if step < 0:
            log.warning("%d step has no previous value for tensor: %s", cur_step, tensor_name)
            return None
        tensor = self._get_tensor(tensor_name, step=step)
        return tensor

    def clean_tensors(self, cur_step):
        """
        Clean the tensor cache.

        Args:
            cur_step (int): The new current step.
        """
        self._cur_step = cur_step
        expired_tensor = []
        for tensor_name, tensor in self._tensors.items():
            expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
            for step in expired_step:
                tensor.pop(step)
            if not tensor:
                expired_tensor.append(tensor_name)
        for tensor_name in expired_tensor:
            self._tensors.pop(tensor_name)
        # NOTE(review): this unconditional reset drops every cached tensor,
        # so the step-based pruning above has no lasting effect. Behaviour is
        # kept as-is; confirm whether the reset or the pruning is intended.
        self._tensors = {}

    def get_tensors_diff(self, tensor_name, shape, tolerance=0):
        """
        Get tensor comparisons data for given name, detail, shape and tolerance.

        Args:
            tensor_name (str): The name of tensor for cache.
            shape (tuple): Specify concrete dimensions of shape.
            tolerance (str): Specify tolerance of difference between current step tensor and previous
                step tensor. Default value is 0. Its is a percentage. The boundary value is equal to
                max(abs(min),abs(max)) * tolerance. The function of min and max is being used to
                calculate the min value and max value of the result of the current step tensor subtract
                the previous step tensor. If the absolute value of result is less than or equal to
                boundary value, the result will set to be zero.

        Raises:
            DebuggerParamValueError, If get current step node and previous step node failed.

        Returns:
            dict, the retrieved data.
        """
        curr_tensor = self.get_tensor_value_by_name(tensor_name)
        prev_tensor = self.get_tensor_value_by_name(tensor_name, prev=True)
        if not (curr_tensor and prev_tensor):
            log.error("Get current step and previous step for this tensor name %s failed.", tensor_name)
            raise DebuggerParamValueError(f"Get current step and previous step for this tensor name "
                                          f"{tensor_name} failed.")
        curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape)
        prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape)
        tensor_info = curr_tensor.get_basic_info()
        if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray):
            diff_tensor = TensorUtils.calc_diff_between_two_tensor(curr_tensor_slice, prev_tensor_slice, tolerance)
            # Last axis holds [previous, current, diff] for each element.
            result = np.stack([prev_tensor_slice, curr_tensor_slice, diff_tensor], axis=-1)
            tensor_info['diff'] = result.tolist()
            stats = TensorUtils.get_statistics_from_tensor(diff_tensor)
            tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats)
        # These fields are not part of the diff reply; tolerate their absence.
        tensor_info.pop('has_prev_step', None)
        tensor_info.pop('value', None)
        reply = {'tensor_value': tensor_info}
        return reply
| @@ -0,0 +1,333 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Define the watchpoint stream handler.""" | |||||
| import numpy as np | |||||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ | |||||
| DebuggerParamTypeError | |||||
| from mindinsight.debugger.common.log import logger as log | |||||
| from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD | |||||
| from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \ | |||||
| WATCHPOINT_CONDITION_MAPPING | |||||
| from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase | |||||
class WatchpointHandler(StreamHandlerBase):
    """Watchpoint handler: caches watchpoints and their pending changes."""

    def __init__(self):
        # All live watchpoints, keyed by watchpoint id.
        self._watchpoints = {}
        # Delete commands recorded since the last `sync_set_cmd`.
        self._deleted_watchpoints = []
        # Watchpoints created or changed since the last `sync_set_cmd`.
        self._updated_watchpoints = {}
        # Id of the watchpoint most recently stored through `put`.
        self._latest_id = 0

    def put(self, value):
        """
        Put Watchpoint into watchpoint handler.

        Args:
            value (Watchpoint): The watchpoint to store; it is also recorded
                as updated.
        """
        watchpoint_id = value.watchpoint_id
        self._watchpoints[watchpoint_id] = value
        self._updated_watchpoints[watchpoint_id] = value
        self._latest_id = watchpoint_id
        log.debug("Put watchpoint %d into cache.", watchpoint_id)

    def sync_set_cmd(self):
        """Forget the recorded updated and deleted watchpoints."""
        self._deleted_watchpoints = []
        self._updated_watchpoints = {}

    def get_watchpoint_by_id(self, watchpoint_id):
        """
        Get watchpoint by watchpoint id.

        Raises:
            DebuggerParamValueError: If no watchpoint is found for the id.
        """
        found = self._watchpoints.get(watchpoint_id)
        if found:
            return found

        log.error("Invalid watchpoint id %d", watchpoint_id)
        raise DebuggerParamValueError("Invalid watchpoint id {}".format(watchpoint_id))

    def get(self, filter_condition=False):
        """
        Get the watchpoints.

        Args:
            filter_condition (bool): If falsy, get the watch condition info of
                every watchpoint. If truthy, get the set commands of the
                updated watchpoints followed by the recorded delete commands.
                Default: False.

        Returns:
            dict, format is {'watch_points': [...]}.
        """
        if filter_condition:
            # get updated watchpoint list
            reply = [watchpoint.get_set_cmd()
                     for watchpoint in self._updated_watchpoints.values()]
            reply.extend(self._deleted_watchpoints)
        else:
            # get watch condition list
            reply = [watchpoint.get_watch_condition_info()
                     for watchpoint in self._watchpoints.values()]

        log.debug("get the watch points with filter_condition:%s", filter_condition)

        return {'watch_points': reply}

    def set_watch_nodes(self, graph, watch_point_id):
        """
        Set the `watched` flag on the nodes of a graph, in place.

        Does nothing when either argument is falsy.

        Args:
            graph (dict): The graph with list of nodes.
            watch_point_id (int): The id of watchpoint.
        """
        if not watch_point_id or not graph:
            return

        self._validate_watchpoint_id(watch_point_id)
        log.debug("add watch flags")
        target_watchpoint = self._watchpoints.get(watch_point_id)
        self._set_watch_status_recursively(graph, target_watchpoint)

    def _set_watch_status_recursively(self, graph, watchpoint):
        """Set the `watched` flag on every named node of `graph`, descending into children and nested nodes."""
        if not isinstance(graph, dict):
            log.warning("The graph is not dict.")
            return

        children = graph.get('children')
        if children:
            self._set_watch_status_recursively(children, watchpoint)

        for node in graph.get('nodes', []):
            if not isinstance(node, dict):
                # A non-dict entry stops processing of the remaining nodes.
                log.warning("The node is not dict.")
                return
            node_name = node.get('name')
            if not node_name:
                continue
            node['watched'] = watchpoint.get_node_status(
                node_name, node.get('type'), node.get('full_name'))
            if node.get('nodes'):
                self._set_watch_status_recursively(node, watchpoint)

    def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None):
        """
        Create watchpoint.

        Args:
            watch_condition (dict): The watch condition.

                - condition (str): Accept `INF` or `NAN`.

                - param (list[float]): Not defined yet.
            watch_nodes (list[NodeBasicInfo]): The list of node basic info.
            watch_point_id (int): The id of an existing watchpoint whose nodes
                are copied when `watch_nodes` is not given.

        Returns:
            int, the new id of watchpoint.
        """
        validate_watch_condition(watch_condition)
        new_id = self._latest_id + 1
        new_watchpoint = Watchpoint(new_id, watch_condition)
        if watch_nodes:
            new_watchpoint.add_nodes(watch_nodes)
        elif watch_point_id:
            self._validate_watchpoint_id(watch_point_id)
            source_watchpoint = self._watchpoints.get(watch_point_id)
            new_watchpoint.copy_nodes_from(source_watchpoint)
        self.put(new_watchpoint)

        return new_id

    def update_watchpoint(self, watch_point_id, watch_nodes, watched=False):
        """
        Update watchpoint.

        The watchpoint is also recorded as updated.

        Args:
            watch_point_id (int): The id of watchpoint.
            watch_nodes (list[str]): The list of node names.
            watched (bool): The update operator on nodes. If False, remove nodes from watch nodes.
                If True, add nodes to watch nodes. Default: False.
        """
        self._validate_watchpoint_id(watch_point_id)
        target_watchpoint = self._watchpoints.get(watch_point_id)
        apply_change = target_watchpoint.add_nodes if watched else target_watchpoint.remove_nodes
        apply_change(watch_nodes)
        self._updated_watchpoints[watch_point_id] = target_watchpoint
        log.debug("Update watchpoint %d in cache.", watch_point_id)

    def delete_watchpoint(self, watch_point_id):
        """
        Delete watchpoint and record a delete command for it.

        Args:
            watch_point_id (int): The id of watchpoint.
        """
        self._validate_watchpoint_id(watch_point_id)
        self._watchpoints.pop(watch_point_id)
        delete_cmd = SetCMD()
        delete_cmd.id = watch_point_id
        delete_cmd.delete = True
        self._deleted_watchpoints.append(delete_cmd)
        log.debug("Delete watchpoint %d in cache.", watch_point_id)

    def _validate_watchpoint_id(self, watch_point_id):
        """
        Validate watchpoint id.

        A falsy id passes without any check.

        Raises:
            DebuggerParamValueError: If a truthy id is not a known watchpoint.
        """
        if not watch_point_id or watch_point_id in self._watchpoints:
            return

        log.error("Invalid watchpoint id: %d.", watch_point_id)
        raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
class WatchpointHitHandler(StreamHandlerBase):
    """Handler that caches watchpoint hits, grouped by node name."""

    def __init__(self):
        # Maps node name -> list of WatchpointHit objects recorded for that node.
        self._hits = {}

    def put(self, value):
        """
        Cache one watchpoint hit. Called by grpc server.

        A hit that compares equal to one already cached for the node is ignored.

        Args:
            value (dict): The watchpoint hit info.

                - tensor_proto (TensorProto): The message about hit tensor.
                - watchpoint (Watchpoint): The Watchpoint that a node hit.
                - node_name (str): The name of the node that was hit.
        """
        node_name = value.get('node_name')
        new_hit = WatchpointHit(
            tensor_proto=value.get('tensor_proto'),
            watchpoint=value.get('watchpoint'),
            node_name=node_name
        )
        node_hits = self._hits.setdefault(node_name, [])
        if new_hit not in node_hits:
            node_hits.append(new_hit)

    def get(self, filter_condition=None):
        """
        Get watchpoint hit list.

        Args:
            filter_condition (str): Get the watchpoint hits of the specified node name.
                If not given, get all watchpoint hits. Default: None.

        Returns:
            dict, the summary built by `get_watchpoint_hits` when no filter is given.
            Otherwise the list of hits cached for that node, or None if there is none.
        """
        if filter_condition is not None:
            log.debug("Get the watchpoint for node: <%s>.", filter_condition)
            return self._hits.get(filter_condition)
        log.debug("Get all watchpoint hit list.")
        return self.get_watchpoint_hits()

    def get_watchpoint_hits(self):
        """Return all cached hits as `{'watch_point_hits': [{'node_name', 'watch_points'}, ...]}`."""
        summary = [
            {
                'node_name': node_name,
                'watch_points': [hit.watchpoint for hit in node_hits]
            }
            for node_name, node_hits in self._hits.items()
        ]
        return {'watch_point_hits': summary}

    def _is_tensor_hit(self, tensor_name):
        """Check if the tensor is recorded in hit cache."""
        # The node name is the part of the tensor name before the first ':'.
        node_hits = self.get(tensor_name.split(':')[0])
        if node_hits is None:
            return False
        return any(tensor_name == hit.tensor_name for hit in node_hits)

    def update_tensor_history(self, tensor_history):
        """
        Add hit flag to tensor history.

        Each entry of `tensor_history['tensor_history']` gets an `is_hit` key.
        Nothing is modified when the hit cache is empty.

        Args:
            tensor_history (dict): The tensor history.
        """
        if not self._hits:
            return
        for tensor_info in tensor_history.get('tensor_history'):
            tensor_info['is_hit'] = self._is_tensor_hit(tensor_info['full_name'])
def validate_watch_condition(watch_condition):
    """
    Validate watch condition.

    Args:
        watch_condition (dict): Watch condition, holding `condition` and optionally `param`.

    Raises:
        DebuggerParamTypeError: If `watch_condition` is not a dict.
        DebuggerParamValueError: If `condition` is not a key of WATCHPOINT_CONDITION_MAPPING.
    """
    if not isinstance(watch_condition, dict):
        log.error("<watch_condition> should be dict. %s received.", watch_condition)
        raise DebuggerParamTypeError("<watch_condition> should be dict.")

    # The condition type must be one of the known mapping keys.
    if watch_condition.get('condition') not in WATCHPOINT_CONDITION_MAPPING:
        log.error("Invalid watch condition. Acceptable values are <%s>.",
                  str(WATCHPOINT_CONDITION_MAPPING.keys()))
        raise DebuggerParamValueError("Invalid watch condition value.")

    # The condition type is known; now check its parameter.
    validate_watch_condition_params(watch_condition)
def validate_watch_condition_params(watch_condition):
    """
    Validate watch condition parameters.

    Args:
        watch_condition (dict): Watch condition.

            - condition (str): Condition type. Should be in WATCHPOINT_CONDITION_MAPPING.
            - param (Union[float, int]): Condition value. Should be given for comparison
              condition. The value will be translated to np.float32.

    Raises:
        DebuggerParamValueError: If a param is given for a 'NAN', 'INF' or 'OVERFLOW'
            condition, or if the param of any other condition is not a real number
            that fits into float32.
    """
    condition = watch_condition.get('condition')
    param = watch_condition.get('param')

    if condition in ['NAN', 'INF', 'OVERFLOW']:
        if param:
            log.error("No param is expected for %s condition.", condition)
            raise DebuggerParamValueError("No param is expected.")
        return

    # bool is a subclass of int and NaN is a float, so both must be rejected explicitly.
    is_number = isinstance(param, (float, int)) and not isinstance(param, bool)
    if not is_number or param != param:
        log.error("Number param should be given for condition <%s>.",
                  condition)
        raise DebuggerParamValueError("Number param should be given.")

    try:
        # Values beyond the float32 range become inf; the overflow is reported below.
        with np.errstate(over='ignore'):
            out_of_range = bool(np.isinf(np.float32(param)))
    except OverflowError:
        # Python ints too large for a C double cannot be converted at all.
        out_of_range = True
    if out_of_range:
        log.error("Condition param should be float32.")
        raise DebuggerParamValueError("The value of condition param should be within float32.")
| @@ -14,19 +14,28 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Start mindinsight service.""" | """Start mindinsight service.""" | ||||
| import argparse | |||||
| import os | import os | ||||
| import sys | |||||
| import re | import re | ||||
| import argparse | |||||
| import sys | |||||
| from importlib import import_module | from importlib import import_module | ||||
| import psutil | import psutil | ||||
| from mindinsight.conf import settings | from mindinsight.conf import settings | ||||
| from mindinsight.utils.command import BaseCommand | from mindinsight.utils.command import BaseCommand | ||||
| from mindinsight.utils.exceptions import PortNotAvailableError | |||||
| from mindinsight.utils.hook import HookUtils | from mindinsight.utils.hook import HookUtils | ||||
| from mindinsight.utils.hook import init | from mindinsight.utils.hook import init | ||||
| from mindinsight.utils.exceptions import PortNotAvailableError | |||||
def str2bool(string):
    """
    Convert str to bool.

    Args:
        string (str): Either 'true' or 'false', compared case-insensitively.

    Returns:
        bool, the converted value.

    Raises:
        ValueError: If the string is neither 'true' nor 'false'.
    """
    known_values = {'false': False, 'true': True}
    lowered = string.lower()
    if lowered in known_values:
        return known_values[lowered]
    raise ValueError
| class ConfigAction(argparse.Action): | class ConfigAction(argparse.Action): | ||||
| @@ -146,6 +155,23 @@ class UrlPathPrefixAction(argparse.Action): | |||||
| setattr(namespace, self.dest, prefix) | setattr(namespace, self.dest, prefix) | ||||
class EnableDebuggerAction(argparse.Action):
    """Enable debugger action class definition."""

    def __call__(self, parser, namespace, values, option_string=None):
        """
        Inherited __call__ method from argparse.Action.

        Stores the parsed value unchanged on the namespace under `self.dest`.

        Args:
            parser (ArgumentParser): Passed-in argument parser.
            namespace (Namespace): Namespace object to hold arguments.
            values (object): Argument values with type depending on argument definition.
            option_string (str): Optional string for specific argument name. Default: None.
        """
        setattr(namespace, self.dest, values)
| class Command(BaseCommand): | class Command(BaseCommand): | ||||
| """ | """ | ||||
| Start mindinsight service. | Start mindinsight service. | ||||
| @@ -186,6 +212,14 @@ class Command(BaseCommand): | |||||
| Custom port ranging from %s to %s. Default value is %s. | Custom port ranging from %s to %s. Default value is %s. | ||||
| """ % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.PORT)) | """ % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.PORT)) | ||||
| parser.add_argument( | |||||
| '--debugger_port', | |||||
| type=int, | |||||
| action=PortAction, | |||||
| help=""" | |||||
| Debugger port ranging from %s to %s. Default value is %s. | |||||
| """ % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.DEBUGGER_PORT)) | |||||
| parser.add_argument( | parser.add_argument( | ||||
| '--url-path-prefix', | '--url-path-prefix', | ||||
| type=str, | type=str, | ||||
| @@ -197,6 +231,14 @@ class Command(BaseCommand): | |||||
| dot or double dots. Default value is ''. | dot or double dots. Default value is ''. | ||||
| """) | """) | ||||
| parser.add_argument( | |||||
| '--enable_debugger', | |||||
| type=str2bool, | |||||
| action=EnableDebuggerAction, | |||||
| default=False, | |||||
| help=""" | |||||
| Enable debugger or not. | |||||
| Dfault is False.""") | |||||
| for hook in HookUtils.instance().hooks(): | for hook in HookUtils.instance().hooks(): | ||||
| hook.register_startup_arguments(parser) | hook.register_startup_arguments(parser) | ||||
| @@ -33,6 +33,7 @@ class MindInsightModules(Enum): | |||||
| SCRIPTCONVERTER = 7 | SCRIPTCONVERTER = 7 | ||||
| WIZARD = 9 | WIZARD = 9 | ||||
| OPTIMIZER = 10 | OPTIMIZER = 10 | ||||
| DEBUGGER = 11 | |||||
| class GeneralErrors(Enum): | class GeneralErrors(Enum): | ||||
| @@ -56,6 +57,10 @@ class LineageMgrErrors(Enum): | |||||
| """Enum definition for lineage errors.""" | """Enum definition for lineage errors.""" | ||||
| class DebuggerErrors(Enum): | |||||
| """Enum definition for debugger errors.""" | |||||
| class DataVisualErrors(Enum): | class DataVisualErrors(Enum): | ||||
| """Enum definition for datavisual errors.""" | """Enum definition for datavisual errors.""" | ||||
| RESTFUL_API_NOT_EXIST = 1 | RESTFUL_API_NOT_EXIST = 1 | ||||
| @@ -0,0 +1,298 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Tensor utils.""" | |||||
| import numpy as np | |||||
| from mindinsight.datavisual.utils.tools import to_int | |||||
| from mindinsight.utils.exceptions import ParamValueError | |||||
| from mindinsight.utils.exceptions import ParamTypeError | |||||
| from mindinsight.utils.log import utils_logger as logger | |||||
# Smallest and largest finite values representable as float32.
F32_MIN = np.finfo(np.float32).min
F32_MAX = np.finfo(np.float32).max
class Statistics:
    """
    Read-only container for the statistics of one tensor.

    Args:
        max_value (float): max value of tensor data.
        min_value (float): min value of tensor data.
        avg_value (float): avg value of tensor data.
        count (int): total count of tensor data.
        nan_count (int): count of NAN.
        neg_inf_count (int): count of negative INF.
        pos_inf_count (int): count of positive INF.
    """

    def __init__(self, max_value=0, min_value=0, avg_value=0,
                 count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0):
        # Value statistics.
        self._max, self._min, self._avg = max_value, min_value, avg_value
        # Element counters.
        self._count = count
        self._nan_count = nan_count
        self._neg_inf_count, self._pos_inf_count = neg_inf_count, pos_inf_count

    @property
    def max(self):
        """Max value of the tensor."""
        return self._max

    @property
    def min(self):
        """Min value of the tensor."""
        return self._min

    @property
    def avg(self):
        """Average value of the tensor."""
        return self._avg

    @property
    def count(self):
        """Total element count of the tensor."""
        return self._count

    @property
    def nan_count(self):
        """Number of NAN elements."""
        return self._nan_count

    @property
    def neg_inf_count(self):
        """Number of negative INF elements."""
        return self._neg_inf_count

    @property
    def pos_inf_count(self):
        """Number of positive INF elements."""
        return self._pos_inf_count
class TensorUtils:
    """Tensor Utils class."""

    @staticmethod
    def validate_dims_format(dims):
        """
        Validate correct of format of dimension parameter.

        Args:
            dims (str): Dims of tensor. Its format is something like this "[0, 0, :, :]".
                None is accepted and skips the validation.

        Raises:
            ParamTypeError: If dims is neither None nor a str.
            ParamValueError: If format of dims is not correct.
        """
        if dims is None:
            return
        if not isinstance(dims, str):
            raise ParamTypeError(dims, str)
        dims = dims.strip()
        if not (dims.startswith('[') and dims.endswith(']')):
            raise ParamValueError('The value: {} of dims must be '
                                  'start with `[` and end with `]`.'.format(dims))
        for dim in dims[1:-1].split(','):
            dim = dim.strip()
            if dim == ":":
                continue
            # A single leading minus sign is allowed (negative index).
            if dim.startswith('-'):
                dim = dim[1:]
            if not dim.isdigit():
                raise ParamValueError('The value: {} of dims in the square brackets '
                                      'must be int or `:`.'.format(dims))

    @staticmethod
    def convert_array_from_str_dims(dims, limit=0):
        """
        Convert string of dims data to array.

        Args:
            dims (str): Specify dims of tensor.
            limit (int): The max flexible dimension count, default value is 0 which means that
                there is no limitation.

        Returns:
            list, a string like this: "[0, 0, :, :]" will convert to this value: [0, 0, None, None].

        Raises:
            ParamValueError, If flexible dimensions exceed limit value.
        """
        dims = dims.strip().lstrip('[').rstrip(']')
        dims_list = []
        count = 0
        for dim in dims.split(','):
            dim = dim.strip()
            if dim == ':':
                # `:` is a flexible dimension; it is represented by None.
                dims_list.append(None)
                count += 1
            else:
                dims_list.append(to_int(dim, "dim"))
        if limit and count > limit:
            raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}"
                                  .format(limit, count))
        return dims_list

    @staticmethod
    def get_specific_dims_data(ndarray, dims, tensor_dims):
        """
        Get specific dims data.

        Args:
            ndarray (numpy.ndarray): An ndarray of numpy.
            dims (list): A list of specific dims. None selects the whole dimension, an int
                (possibly negative, counted from the end) selects one index.
            tensor_dims (list): A list of tensor dims.

        Returns:
            numpy.ndarray, an ndarray of specific dims tensor data.

        Raises:
            ParamValueError, If the length of param dims is not equal to the length of tensor dims or
                the index of param dims out of range.
        """
        if len(dims) != len(tensor_dims):
            raise ParamValueError("The length of param dims: {}, is not equal to the "
                                  "length of tensor dims: {}.".format(len(dims), len(tensor_dims)))
        indices = []
        for k, d in enumerate(dims):
            if d is None:
                indices.append(slice(0, tensor_dims[k]))
                continue
            # Negative indexes are valid down to -size; check both bounds so that an
            # out-of-range index never surfaces as numpy's IndexError.
            if d >= tensor_dims[k] or d < -tensor_dims[k]:
                raise ParamValueError("The index: {} of param dims out of range: {}.".format(d, tensor_dims[k]))
            indices.append(d)
        result = ndarray[tuple(indices)]
        # Make sure the return type is numpy.ndarray.
        if not isinstance(result, np.ndarray):
            result = np.array(result)
        return result

    @staticmethod
    def get_statistics_from_tensor(tensors):
        """
        Calculates statistics data of tensor.

        Args:
            tensors (numpy.ndarray): An numpy.ndarray of tensor data. Other array-like
                input is converted with `numpy.asarray`.

        Returns:
            an instance of Statistics.
        """
        if not isinstance(tensors, np.ndarray):
            tensors = np.asarray(tensors)
        ma_value = np.ma.masked_invalid(tensors)
        total, valid = tensors.size, ma_value.count()
        invalids = []
        for isfn in np.isnan, np.isposinf, np.isneginf:
            # Stop scanning the tensor as soon as every invalid element is accounted for.
            if total - valid > sum(invalids):
                count = np.count_nonzero(isfn(tensors))
                invalids.append(count)
            else:
                invalids.append(0)

        nan_count, pos_inf_count, neg_inf_count = invalids
        if not valid:
            logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape)
            statistics = Statistics(max_value=0,
                                    min_value=0,
                                    avg_value=0,
                                    count=total,
                                    nan_count=nan_count,
                                    neg_inf_count=neg_inf_count,
                                    pos_inf_count=pos_inf_count)
            return statistics

        # BUG: max of a masked array with dtype np.float16 returns inf
        # See numpy issue#15077
        if issubclass(tensors.dtype.type, np.floating):
            # `np.inf` is used because the `np.PINF`/`np.NINF` aliases were removed in NumPy 2.0.
            tensor_min = ma_value.min(fill_value=np.inf)
            tensor_max = ma_value.max(fill_value=-np.inf)
            if tensor_min < F32_MIN or tensor_max > F32_MAX:
                logger.warning('Values(%f, %f) are too large, you may encounter some undefined '
                               'behaviours hereafter.', tensor_min, tensor_max)
        else:
            tensor_min = ma_value.min()
            tensor_max = ma_value.max()
        # Accumulate in float64 so the sum does not overflow or lose precision in narrow dtypes.
        tensor_sum = ma_value.sum(dtype=np.float64)
        statistics = Statistics(max_value=tensor_max,
                                min_value=tensor_min,
                                avg_value=tensor_sum / valid,
                                count=total,
                                nan_count=nan_count,
                                neg_inf_count=neg_inf_count,
                                pos_inf_count=pos_inf_count)
        return statistics

    @staticmethod
    def get_statistics_dict(stats):
        """
        Get statistics dict according to statistics value.

        Args:
            stats (Statistics): An instance of Statistics.

        Returns:
            dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'.
        """
        statistics = {
            "max": float(stats.max),
            "min": float(stats.min),
            "avg": float(stats.avg),
            "count": stats.count,
            "nan_count": stats.nan_count,
            "neg_inf_count": stats.neg_inf_count,
            "pos_inf_count": stats.pos_inf_count
        }
        return statistics

    @staticmethod
    def calc_diff_between_two_tensor(first_tensor, second_tensor, tolerance):
        """
        Calculate the difference between the first tensor and the second tensor.

        Args:
            first_tensor (numpy.ndarray): Specify the first tensor.
            second_tensor (numpy.ndarray): Specify the second tensor.
            tolerance (float): The tolerance of difference between the first tensor and the second tensor.
                Its is a percentage. The boundary value is equal to max(abs(min),abs(max)) * tolerance.
                The function of min and max is being used to calculate the min value and max value of
                the result of the first tensor subtract the second tensor. If the absolute value of
                result is less than or equal to boundary value, the result will set to be zero.

        Returns:
            numpy.ndarray, the value of the first tensor subtract the second tensor, with every
            element set to zero when it is less than or equal to the boundary value.

        Raises:
            ParamTypeError: If the type of these two tensors is not the numpy.ndarray.
            ParamValueError: If the shape or dtype is not the same of these two tensors.
        """
        if not isinstance(first_tensor, np.ndarray):
            raise ParamTypeError('first_tensor', np.ndarray)

        if not isinstance(second_tensor, np.ndarray):
            raise ParamTypeError('second_tensor', np.ndarray)

        if first_tensor.shape != second_tensor.shape:
            raise ParamValueError("the shape: {} of first tensor is not equal to shape: {} of second tensor."
                                  .format(first_tensor.shape, second_tensor.shape))

        if first_tensor.dtype != second_tensor.dtype:
            raise ParamValueError("the dtype: {} of first tensor is not equal to dtype: {} of second tensor."
                                  .format(first_tensor.dtype, second_tensor.dtype))

        diff_tensor = np.subtract(first_tensor, second_tensor)
        stats = TensorUtils.get_statistics_from_tensor(diff_tensor)
        boundary_value = max(abs(stats.max), abs(stats.min)) * tolerance
        is_close = np.isclose(first_tensor, second_tensor, atol=boundary_value, rtol=0)
        # Multiplying by the inverted mask zeroes out every element within the boundary.
        result = np.multiply(diff_tensor, ~is_close)
        return result
| @@ -16,4 +16,5 @@ Werkzeug>=1.0.0 | |||||
| tabulate>=0.8.6 | tabulate>=0.8.6 | ||||
| pandas>=1.0.4 | pandas>=1.0.4 | ||||
| yapf>=0.30.0 | yapf>=0.30.0 | ||||
| treelib>=1.6.1 | |||||
| treelib>=1.6.1 | |||||
| grpcio>=1.29.0 | |||||
| @@ -14,9 +14,11 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Test tensor container.""" | """Test tensor container.""" | ||||
| import unittest.mock as mock | import unittest.mock as mock | ||||
| import numpy as np | import numpy as np | ||||
| from mindinsight.datavisual.data_transform import tensor_container as tensor | from mindinsight.datavisual.data_transform import tensor_container as tensor | ||||
| from mindinsight.utils.tensor import TensorUtils | |||||
| class TestTensorContainer: | class TestTensorContainer: | ||||
| @@ -34,8 +36,9 @@ class TestTensorContainer: | |||||
| def test_get_statistics_from_tensor(self): | def test_get_statistics_from_tensor(self): | ||||
| """Tests get statistics from tensor.""" | """Tests get statistics from tensor.""" | ||||
| ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape([2, 2, 2]) | |||||
| statistics = tensor.get_statistics_from_tensor(ndarray) | |||||
| ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape( | |||||
| [2, 2, 2]) | |||||
| statistics = TensorUtils.get_statistics_from_tensor(ndarray) | |||||
| assert (statistics.max, statistics.min, statistics.avg, statistics.count, | assert (statistics.max, statistics.min, statistics.avg, statistics.count, | ||||
| statistics.nan_count, statistics.neg_inf_count, statistics.pos_inf_count) == \ | statistics.nan_count, statistics.neg_inf_count, statistics.pos_inf_count) == \ | ||||
| (5, 1, 3, 8, | (5, 1, 3, 8, | ||||
| @@ -43,8 +46,9 @@ class TestTensorContainer: | |||||
| def test_calc_original_buckets(self): | def test_calc_original_buckets(self): | ||||
| """Tests calculate original buckets.""" | """Tests calculate original buckets.""" | ||||
| ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape([2, 2, 2]) | |||||
| statistics = tensor.get_statistics_from_tensor(ndarray) | |||||
| ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape( | |||||
| [2, 2, 2]) | |||||
| statistics = TensorUtils.get_statistics_from_tensor(ndarray) | |||||
| buckets = tensor.calc_original_buckets(ndarray, statistics) | buckets = tensor.calc_original_buckets(ndarray, statistics) | ||||
| assert (buckets[0].left, buckets[0].width, buckets[0].count) == (1, 2, 2) | assert (buckets[0].left, buckets[0].width, buckets[0].count) == (1, 2, 2) | ||||
| @@ -29,10 +29,8 @@ from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||||
| from mindinsight.datavisual.common.exceptions import TensorNotExistError | from mindinsight.datavisual.common.exceptions import TensorNotExistError | ||||
| from mindinsight.datavisual.data_transform import data_manager | from mindinsight.datavisual.data_transform import data_manager | ||||
| from mindinsight.datavisual.data_transform.tensor_container import calc_original_buckets | from mindinsight.datavisual.data_transform.tensor_container import calc_original_buckets | ||||
| from mindinsight.datavisual.data_transform.tensor_container import get_statistics_from_tensor | |||||
| from mindinsight.datavisual.processors.tensor_processor import TensorProcessor | from mindinsight.datavisual.processors.tensor_processor import TensorProcessor | ||||
| from mindinsight.datavisual.processors.tensor_processor import get_specific_dims_data | |||||
| from mindinsight.datavisual.processors.tensor_processor import get_statistics_dict | |||||
| from mindinsight.utils.tensor import TensorUtils | |||||
| from mindinsight.datavisual.utils import crc32 | from mindinsight.datavisual.utils import crc32 | ||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| from mindinsight.utils.exceptions import ParamMissError | from mindinsight.utils.exceptions import ParamMissError | ||||
| @@ -187,7 +185,7 @@ class TestTensorProcessor: | |||||
| dims = expected_values.get('value').get("dims") | dims = expected_values.get('value').get("dims") | ||||
| expected_data = np.array(expected_values.get('value').get("float_data")).reshape(dims) | expected_data = np.array(expected_values.get('value').get("float_data")).reshape(dims) | ||||
| recv_tensor = np.array(recv_values.get('value').get("data")) | recv_tensor = np.array(recv_values.get('value').get("data")) | ||||
| expected_tensor = get_specific_dims_data(expected_data, [0, 0, None, None], dims) | |||||
| expected_tensor = TensorUtils.get_specific_dims_data(expected_data, [0, 0, None, None], dims) | |||||
| assert np.sum(np.isclose(recv_tensor, expected_tensor, rtol=1e-6) == 0) == 0 | assert np.sum(np.isclose(recv_tensor, expected_tensor, rtol=1e-6) == 0) == 0 | ||||
| @pytest.mark.usefixtures('load_tensor_record') | @pytest.mark.usefixtures('load_tensor_record') | ||||
| @@ -204,7 +202,8 @@ class TestTensorProcessor: | |||||
| assert recv_values.get('wall_time') == expected_values.get('wall_time') | assert recv_values.get('wall_time') == expected_values.get('wall_time') | ||||
| assert recv_values.get('step') == expected_values.get('step') | assert recv_values.get('step') == expected_values.get('step') | ||||
| expected_data = expected_values.get('value').get("float_data") | expected_data = expected_values.get('value').get("float_data") | ||||
| expected_statistic = get_statistics_dict(get_statistics_from_tensor(expected_data)) | |||||
| expected_statistic_instance = TensorUtils.get_statistics_from_tensor(expected_data) | |||||
| expected_statistic = TensorUtils.get_statistics_dict(expected_statistic_instance) | |||||
| recv_statistic = recv_values.get('value').get("statistics") | recv_statistic = recv_values.get('value').get("statistics") | ||||
| assert recv_statistic.get("max") - expected_statistic.get("max") < 1e-6 | assert recv_statistic.get("max") - expected_statistic.get("max") < 1e-6 | ||||
| assert recv_statistic.get("min") - expected_statistic.get("min") < 1e-6 | assert recv_statistic.get("min") - expected_statistic.get("min") < 1e-6 | ||||
| @@ -225,7 +224,7 @@ class TestTensorProcessor: | |||||
| assert recv_values.get('wall_time') == expected_values.get('wall_time') | assert recv_values.get('wall_time') == expected_values.get('wall_time') | ||||
| assert recv_values.get('step') == expected_values.get('step') | assert recv_values.get('step') == expected_values.get('step') | ||||
| expected_data = expected_values.get('value').get("float_data") | expected_data = expected_values.get('value').get("float_data") | ||||
| expected_statistic = get_statistics_from_tensor(expected_data) | |||||
| expected_statistic = TensorUtils.get_statistics_from_tensor(expected_data) | |||||
| expected_buckets = calc_original_buckets(expected_data, expected_statistic) | expected_buckets = calc_original_buckets(expected_data, expected_statistic) | ||||
| recv_buckets = recv_values.get('value').get("histogram_buckets") | recv_buckets = recv_values.get('value').get("histogram_buckets") | ||||