From 50e140050558f29993c6345b1ee054fd721fb656 Mon Sep 17 00:00:00 2001 From: yelihua Date: Fri, 11 Sep 2020 22:17:32 +0800 Subject: [PATCH] add debugger module --- mindinsight/backend/debugger/__init__.py | 26 + mindinsight/backend/debugger/debugger_api.py | 300 ++++ mindinsight/conf/defaults.py | 6 + .../data_transform/tensor_container.py | 118 +- .../datavisual/processors/tensor_processor.py | 114 +- mindinsight/datavisual/utils/tools.py | 19 + mindinsight/debugger/__init__.py | 20 + mindinsight/debugger/common/__init__.py | 14 + .../debugger/common/exceptions/__init__.py | 14 + .../debugger/common/exceptions/error_code.py | 56 + .../debugger/common/exceptions/exceptions.py | 117 ++ mindinsight/debugger/common/log.py | 20 + mindinsight/debugger/common/utils.py | 168 ++ mindinsight/debugger/debugger_cache.py | 154 ++ mindinsight/debugger/debugger_grpc_server.py | 309 ++++ mindinsight/debugger/debugger_server.py | 752 +++++++++ mindinsight/debugger/proto/debug_grpc.proto | 113 ++ mindinsight/debugger/proto/debug_grpc_pb2.py | 683 ++++++++ .../debugger/proto/debug_grpc_pb2_grpc.py | 193 +++ mindinsight/debugger/proto/ms_graph.proto | 322 ++++ mindinsight/debugger/proto/ms_graph_pb2.py | 1395 +++++++++++++++++ mindinsight/debugger/stream_cache/__init__.py | 14 + .../debugger/stream_cache/debugger_graph.py | 289 ++++ mindinsight/debugger/stream_cache/node.py | 61 + mindinsight/debugger/stream_cache/tensor.py | 233 +++ .../debugger/stream_cache/watchpoint.py | 300 ++++ .../debugger/stream_handler/__init__.py | 23 + .../debugger/stream_handler/base_handler.py | 34 + .../debugger/stream_handler/event_handler.py | 159 ++ .../debugger/stream_handler/graph_handler.py | 314 ++++ .../stream_handler/metadata_handler.py | 131 ++ .../debugger/stream_handler/tensor_handler.py | 298 ++++ .../stream_handler/watchpoint_handler.py | 333 ++++ mindinsight/scripts/start.py | 48 +- mindinsight/utils/constant.py | 5 + mindinsight/utils/tensor.py | 298 ++++ requirements.txt | 3 +- .../data_transform/test_tensor_container.py | 12 +- .../processors/test_tensor_processor.py | 11 +- 39 files changed, 7244 insertions(+), 235 deletions(-) create mode 100644 mindinsight/backend/debugger/__init__.py create mode 100644 mindinsight/backend/debugger/debugger_api.py create mode 100644 mindinsight/debugger/__init__.py create mode 100644 mindinsight/debugger/common/__init__.py create mode 100644 mindinsight/debugger/common/exceptions/__init__.py create mode 100644 mindinsight/debugger/common/exceptions/error_code.py create mode 100644 mindinsight/debugger/common/exceptions/exceptions.py create mode 100644 mindinsight/debugger/common/log.py create mode 100644 mindinsight/debugger/common/utils.py create mode 100644 mindinsight/debugger/debugger_cache.py create mode 100644 mindinsight/debugger/debugger_grpc_server.py create mode 100644 mindinsight/debugger/debugger_server.py create mode 100644 mindinsight/debugger/proto/debug_grpc.proto create mode 100644 mindinsight/debugger/proto/debug_grpc_pb2.py create mode 100644 mindinsight/debugger/proto/debug_grpc_pb2_grpc.py create mode 100644 mindinsight/debugger/proto/ms_graph.proto create mode 100644 mindinsight/debugger/proto/ms_graph_pb2.py create mode 100644 mindinsight/debugger/stream_cache/__init__.py create mode 100644 mindinsight/debugger/stream_cache/debugger_graph.py create mode 100644 mindinsight/debugger/stream_cache/node.py create mode 100644 mindinsight/debugger/stream_cache/tensor.py create mode 100644 mindinsight/debugger/stream_cache/watchpoint.py create mode 100644 
mindinsight/debugger/stream_handler/__init__.py create mode 100644 mindinsight/debugger/stream_handler/base_handler.py create mode 100644 mindinsight/debugger/stream_handler/event_handler.py create mode 100644 mindinsight/debugger/stream_handler/graph_handler.py create mode 100644 mindinsight/debugger/stream_handler/metadata_handler.py create mode 100644 mindinsight/debugger/stream_handler/tensor_handler.py create mode 100644 mindinsight/debugger/stream_handler/watchpoint_handler.py create mode 100644 mindinsight/utils/tensor.py diff --git a/mindinsight/backend/debugger/__init__.py b/mindinsight/backend/debugger/__init__.py new file mode 100644 index 00000000..41064a87 --- /dev/null +++ b/mindinsight/backend/debugger/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Module init file.""" +from mindinsight.backend.debugger.debugger_api import init_module as init_query_module + + +def init_module(app): + """ + Init module entry. + + Args: + app (Flask): A Flask instance. + """ + init_query_module(app) diff --git a/mindinsight/backend/debugger/debugger_api.py b/mindinsight/backend/debugger/debugger_api.py new file mode 100644 index 00000000..2b46c02e --- /dev/null +++ b/mindinsight/backend/debugger/debugger_api.py @@ -0,0 +1,300 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Debugger restful api.""" +import json + +from flask import Blueprint, jsonify, request + +from mindinsight.conf import settings +from mindinsight.debugger.debugger_server import DebuggerServer +from mindinsight.utils.exceptions import ParamValueError + +BLUEPRINT = Blueprint("debugger", __name__, + url_prefix=settings.URL_PATH_PREFIX + settings.API_PREFIX) + + +def _initialize_debugger_server(): + """Initialize a debugger server instance.""" + port = settings.DEBUGGER_PORT if hasattr(settings, 'DEBUGGER_PORT') else None + enable_debugger = settings.ENABLE_DEBUGGER if hasattr(settings, 'ENABLE_DEBUGGER') else False + server = None + if port and enable_debugger: + server = DebuggerServer(port) + return server + + +def _read_post_request(post_request): + """ + Extract the body of post request. + + Args: + post_request (object): The post request. + + Returns: + dict, the deserialized body of request. 
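+ + Examples: + >>> # illustrative only: the body of a retrieve request deserializes to a dict such as + >>> _read_post_request(request) + {'mode': 'all', 'params': None}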
+ """ + body = post_request.stream.read() + try: + body = json.loads(body if body else "{}") + except Exception: + raise ParamValueError("Json data parse failed.") + return body + + +def _wrap_reply(func, *args, **kwargs): + """Serialize reply.""" + reply = func(*args, **kwargs) + return jsonify(reply) + + +@BLUEPRINT.route("/debugger/poll_data", methods=["GET"]) +def poll_data(): + """ + Wait for data to be updated on UI. + + Get data from server and display the change on UI. + + Returns: + str, the updated data. + + Examples: + >>> Get http://xxxx/v1/mindinsight/debugger/poll_data?pos=xx + """ + pos = request.args.get('pos') + + reply = _wrap_reply(BACKEND_SERVER.poll_data, pos) + + return reply + + +@BLUEPRINT.route("/debugger/search", methods=["GET"]) +def search(): + """ + Search nodes in specified watchpoint. + + Returns: + str, the required data. + + Examples: + >>> Get http://xxxx/v1/mindinsight/debugger/retrive?mode=all + """ + name = request.args.get('name') + watch_point_id = int(request.args.get('watch_point_id', 0)) + reply = _wrap_reply(BACKEND_SERVER.search, name, watch_point_id) + + return reply + + +@BLUEPRINT.route("/debugger/retrieve_node_by_bfs", methods=["GET"]) +def retrieve_node_by_bfs(): + """ + Search node by bfs. + + Returns: + str, the required data. + + Examples: + >>> Get http://xxxx/v1/mindinsight/debugger/retrieve_node_by_bfs?name=node_name&ascend=true + """ + name = request.args.get('name') + ascend = request.args.get('ascend', 'false') + ascend = ascend == 'true' + reply = _wrap_reply(BACKEND_SERVER.retrieve_node_by_bfs, name, ascend) + + return reply + + +@BLUEPRINT.route("/debugger/tensor-comparisons", methods=["GET"]) +def tensor_comparisons(): + """ + Get tensor comparisons. + + Returns: + str, the required data. + + Examples: + >>> Get http://xxxx/v1/mindinsight/debugger/tensor-comparisons? + name=node_name&detail=data&shape=[0, 0, :, :]&tolerance=0.5 + """ + name = request.args.get('name') + detail = request.args.get('detail', 'data') + shape = request.args.get('shape') + tolerance = request.args.get('tolerance', '0') + reply = _wrap_reply(BACKEND_SERVER.tensor_comparisons, name, shape, detail, tolerance) + + return reply + + +@BLUEPRINT.route("/debugger/retrieve", methods=["POST"]) +def retrieve(): + """ + Retrieve data according to mode and params. + + Returns: + str, the required data. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/retrieve + """ + body = _read_post_request(request) + mode = body.get('mode') + params = body.get('params') + reply = _wrap_reply(BACKEND_SERVER.retrieve, mode, params) + return reply + + +@BLUEPRINT.route("/debugger/retrieve_tensor_history", methods=["POST"]) +def retrieve_tensor_history(): + """ + Retrieve data according to mode and params. + + Returns: + str, the required data. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/retrieve_tensor_history + """ + body = _read_post_request(request) + name = body.get('name') + reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_history, name) + return reply + + +@BLUEPRINT.route("/debugger/tensors", methods=["GET"]) +def retrieve_tensor_value(): + """ + Retrieve tensor value according to name and shape. + + Returns: + str, the required data. 
+ + Examples: + >>> GET http://xxxx/v1/mindinsight/debugger/tensors?name=node_name&detail=data&shape=[1,1,:,:] + """ + name = request.args.get('name') + detail = request.args.get('detail') + shape = request.args.get('shape') + reply = _wrap_reply(BACKEND_SERVER.retrieve_tensor_value, name, detail, shape) + return reply + + +@BLUEPRINT.route("/debugger/create_watchpoint", methods=["POST"]) +def create_watchpoint(): + """ + Create watchpoint. + + Returns: + str, watchpoint id. + + Raises: + MindInsightException: If method fails to be called. + ParamValueError: If parsing json data search_condition fails. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/create_watchpoint + """ + body = _read_post_request(request) + + condition = body.get('condition') + watch_nodes = body.get('watch_nodes') + watch_point_id = body.get('watch_point_id') + reply = _wrap_reply(BACKEND_SERVER.create_watchpoint, condition, watch_nodes, watch_point_id) + return reply + + +@BLUEPRINT.route("/debugger/update_watchpoint", methods=["POST"]) +def update_watchpoint(): + """ + Update watchpoint. + + Returns: + str, reply message. + + Raises: + MindInsightException: If method fails to be called. + ParamValueError: If parsing json data search_condition fails. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/update_watchpoint + """ + body = _read_post_request(request) + + watch_point_id = body.get('watch_point_id') + watch_nodes = body.get('watch_nodes') + mode = body.get('mode') + name = body.get('name') + reply = _wrap_reply(BACKEND_SERVER.update_watchpoint, watch_point_id, watch_nodes, mode, name) + + return reply + + +@BLUEPRINT.route("/debugger/delete_watchpoint", methods=["POST"]) +def delete_watchpoint(): + """ + delete watchpoint. + + Returns: + str, reply message. + + Raises: + MindInsightException: If method fails to be called. + ParamValueError: If parsing json data search_condition fails. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/delete_watchpoint + """ + body = _read_post_request(request) + + watch_point_id = body.get('watch_point_id') + + reply = _wrap_reply(BACKEND_SERVER.delete_watchpoint, watch_point_id) + + return reply + + +@BLUEPRINT.route("/debugger/control", methods=["POST"]) +def control(): + """ + Control request. + + Returns: + str, reply message. + + Raises: + MindInsightException: If method fails to be called. + ParamValueError: If parsing json data search_condition fails. + + Examples: + >>> POST http://xxxx/v1/mindinsight/debugger/control + """ + params = _read_post_request(request) + reply = _wrap_reply(BACKEND_SERVER.control, params) + + return reply + + +BACKEND_SERVER = _initialize_debugger_server() + + +def init_module(app): + """ + Init module entry. + + Args: + app (Flask): The application obj. + """ + app.register_blueprint(BLUEPRINT) + if BACKEND_SERVER: + BACKEND_SERVER.start() diff --git a/mindinsight/conf/defaults.py b/mindinsight/conf/defaults.py index b554869e..beaf4b65 100644 --- a/mindinsight/conf/defaults.py +++ b/mindinsight/conf/defaults.py @@ -26,6 +26,12 @@ WORKSPACE = os.path.join(os.environ['HOME'], 'mindinsight') PORT = 8080 URL_PATH_PREFIX = '' +#################################### +# Debugger default settings. +#################################### +DEBUGGER_PORT = '50051' +ENABLE_DEBUGGER = False + #################################### # Datavisual default settings. 
#################################### diff --git a/mindinsight/datavisual/data_transform/tensor_container.py b/mindinsight/datavisual/data_transform/tensor_container.py index ef430b38..494f8ed3 100644 --- a/mindinsight/datavisual/data_transform/tensor_container.py +++ b/mindinsight/datavisual/data_transform/tensor_container.py @@ -15,128 +15,14 @@ """Tensor data container.""" import numpy as np -from mindinsight.datavisual.common.log import logger from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket from mindinsight.datavisual.utils.utils import calc_histogram_bins from mindinsight.utils.exceptions import ParamValueError +from mindinsight.utils.tensor import TensorUtils -F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max MAX_TENSOR_COUNT = 10000000 -class Statistics: - """Statistics data class. - - Args: - max_value (float): max value of tensor data. - min_value (float): min value of tensor data. - avg_value (float): avg value of tensor data. - count (int): total count of tensor data. - nan_count (int): count of NAN. - neg_inf_count (int): count of negative INF. - pos_inf_count (int): count of positive INF. - """ - - def __init__(self, max_value=0, min_value=0, avg_value=0, - count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0): - self._max = max_value - self._min = min_value - self._avg = avg_value - self._count = count - self._nan_count = nan_count - self._neg_inf_count = neg_inf_count - self._pos_inf_count = pos_inf_count - - @property - def max(self): - """Get max value of tensor.""" - return self._max - - @property - def min(self): - """Get min value of tensor.""" - return self._min - - @property - def avg(self): - """Get avg value of tensor.""" - return self._avg - - @property - def count(self): - """Get total count of tensor.""" - return self._count - - @property - def nan_count(self): - """Get count of NAN.""" - return self._nan_count - - @property - def neg_inf_count(self): - """Get count of negative INF.""" - return self._neg_inf_count - - @property - def pos_inf_count(self): - """Get count of positive INF.""" - return self._pos_inf_count - - -def get_statistics_from_tensor(tensors): - """ - Calculates statistics data of tensor. - - Args: - tensors (numpy.ndarray): An numpy.ndarray of tensor data. - - Returns: - an instance of Statistics. 
- """ - ma_value = np.ma.masked_invalid(tensors) - total, valid = tensors.size, ma_value.count() - invalids = [] - for isfn in np.isnan, np.isposinf, np.isneginf: - if total - valid > sum(invalids): - count = np.count_nonzero(isfn(tensors)) - invalids.append(count) - else: - invalids.append(0) - - nan_count, pos_inf_count, neg_inf_count = invalids - if not valid: - logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape) - statistics = Statistics(max_value=0, - min_value=0, - avg_value=0, - count=total, - nan_count=nan_count, - neg_inf_count=neg_inf_count, - pos_inf_count=pos_inf_count) - return statistics - - # BUG: max of a masked array with dtype np.float16 returns inf - # See numpy issue#15077 - if issubclass(tensors.dtype.type, np.floating): - tensor_min = ma_value.min(fill_value=np.PINF) - tensor_max = ma_value.max(fill_value=np.NINF) - if tensor_min < F32_MIN or tensor_max > F32_MAX: - logger.warning('Values(%f, %f) are too large, you may encounter some undefined ' - 'behaviours hereafter.', tensor_min, tensor_max) - else: - tensor_min = ma_value.min() - tensor_max = ma_value.max() - tensor_sum = ma_value.sum(dtype=np.float64) - statistics = Statistics(max_value=tensor_max, - min_value=tensor_min, - avg_value=tensor_sum / valid, - count=total, - nan_count=nan_count, - neg_inf_count=neg_inf_count, - pos_inf_count=pos_inf_count) - return statistics - - def calc_original_buckets(np_value, stats): """ Calculate buckets from tensor data. @@ -188,7 +74,7 @@ class TensorContainer: self._dims = tuple(tensor_message.dims) self._data_type = tensor_message.data_type self._np_array = self.get_ndarray(tensor_message.float_data) - self._stats = get_statistics_from_tensor(self._np_array) + self._stats = TensorUtils.get_statistics_from_tensor(self._np_array) original_buckets = calc_original_buckets(self._np_array, self._stats) self._count = sum(bucket.count for bucket in original_buckets) self._max = self._stats.max diff --git a/mindinsight/datavisual/processors/tensor_processor.py b/mindinsight/datavisual/processors/tensor_processor.py index 27a01ffd..4fd96c54 100644 --- a/mindinsight/datavisual/processors/tensor_processor.py +++ b/mindinsight/datavisual/processors/tensor_processor.py @@ -19,97 +19,16 @@ import numpy as np from mindinsight.datavisual.utils.tools import to_int from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError +from mindinsight.utils.tensor import TensorUtils from mindinsight.conf.constants import MAX_TENSOR_RESPONSE_DATA_SIZE from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.common.exceptions import StepTensorDataNotInCacheError, TensorNotExistError from mindinsight.datavisual.common.exceptions import ResponseDataExceedMaxValueError -from mindinsight.datavisual.data_transform.tensor_container import TensorContainer, get_statistics_from_tensor +from mindinsight.datavisual.data_transform.tensor_container import TensorContainer from mindinsight.datavisual.processors.base_processor import BaseProcessor from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 -def convert_array_from_str(dims, limit=0): - """ - Convert string of dims data to array. - - Args: - dims (str): Specify dims of tensor. - limit (int): The max flexible dimension count, default value is 0 which means that there is no limitation. - - Returns: - list, a string like this: "[0, 0, :, :]" will convert to this value: [0, 0, None, None]. 
- - Raises: - ParamValueError, If flexible dimensions exceed limit value. - """ - dims = dims.strip().lstrip('[').rstrip(']') - dims_list = [] - count = 0 - for dim in dims.split(','): - dim = dim.strip() - if dim == ':': - dims_list.append(None) - count += 1 - else: - dims_list.append(to_int(dim, "dim")) - if limit and count > limit: - raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}" - .format(limit, count)) - return dims_list - - -def get_specific_dims_data(ndarray, dims, tensor_dims): - """ - Get specific dims data. - - Args: - ndarray (numpy.ndarray): An ndarray of numpy. - dims (list): A list of specific dims. - tensor_dims (list): A list of tensor dims. - - Returns: - numpy.ndarray, an ndarray of specific dims tensor data. - - Raises: - ParamValueError, If the length of param dims is not equal to the length of tensor dims or - the index of param dims out of range. - """ - if len(dims) != len(tensor_dims): - raise ParamValueError("The length of param dims: {}, is not equal to the " - "length of tensor dims: {}.".format(len(dims), len(tensor_dims))) - indices = [] - for k, d in enumerate(dims): - if d is not None: - if d >= tensor_dims[k]: - raise ParamValueError("The index: {} of param dims out of range: {}.".format(d, tensor_dims[k])) - indices.append(d) - else: - indices.append(slice(0, tensor_dims[k])) - return ndarray[tuple(indices)] - - -def get_statistics_dict(stats): - """ - Get statistics dict according to statistics value. - - Args: - stats (Statistics): An instance of Statistics. - - Returns: - dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'. - """ - statistics = { - "max": float(stats.max), - "min": float(stats.min), - "avg": float(stats.avg), - "count": stats.count, - "nan_count": stats.nan_count, - "neg_inf_count": stats.neg_inf_count, - "pos_inf_count": stats.pos_inf_count - } - return statistics - - class TensorProcessor(BaseProcessor): """Tensor Processor.""" def get_tensors(self, train_ids, tags, step, dims, detail): @@ -130,22 +49,7 @@ class TensorProcessor(BaseProcessor): UrlDecodeError, If unquote train id error with strict mode. 
""" Validation.check_param_empty(train_id=train_ids, tag=tags) - if dims is not None: - if not isinstance(dims, str): - raise ParamValueError('The type of dims must be str, but got {}.'.format(type(dims))) - dims = dims.strip() - if not (dims.startswith('[') and dims.endswith(']')): - raise ParamValueError('The value: {} of dims must be ' - 'start with `[` and end with `]`.'.format(dims)) - for dim in dims[1:-1].split(','): - dim = dim.strip() - if dim == ":": - continue - if dim.startswith('-'): - dim = dim[1:] - if not dim.isdigit(): - raise ParamValueError('The value: {} of dims in the square brackets ' - 'must be int or `:`.'.format(dims)) + TensorUtils.validate_dims_format(dims) for index, train_id in enumerate(train_ids): try: @@ -248,7 +152,7 @@ class TensorProcessor(BaseProcessor): "data_type": anf_ir_pb2.DataType.Name(value.data_type) } if detail and detail == 'stats': - stats = get_statistics_dict(value.stats) + stats = TensorUtils.get_statistics_dict(value.stats) value_dict.update({"statistics": stats}) values.append({ @@ -295,14 +199,14 @@ class TensorProcessor(BaseProcessor): """ values = [] step_in_cache = False - dims = convert_array_from_str(dims, limit=2) + dims = TensorUtils.convert_array_from_str_dims(dims, limit=2) for tensor in tensors: # This value is an instance of TensorContainer value = tensor.value if step != tensor.step: continue step_in_cache = True - res_data = get_specific_dims_data(value.ndarray, dims, list(value.dims)) + res_data = TensorUtils.get_specific_dims_data(value.ndarray, dims, list(value.dims)) flatten_data = res_data.flatten().tolist() if len(flatten_data) > MAX_TENSOR_RESPONSE_DATA_SIZE: raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}." @@ -328,7 +232,7 @@ class TensorProcessor(BaseProcessor): transfer_data[index] = float(data) return transfer_data - stats = get_statistics_from_tensor(res_data) + stats = TensorUtils.get_statistics_from_tensor(res_data) if stats.nan_count + stats.neg_inf_count + stats.pos_inf_count > 0: tensor_data = transfer(res_data) else: @@ -340,7 +244,7 @@ class TensorProcessor(BaseProcessor): "dims": value.dims, "data_type": anf_ir_pb2.DataType.Name(value.data_type), "data": tensor_data, - "statistics": get_statistics_dict(stats) + "statistics": TensorUtils.get_statistics_dict(stats) } }) break @@ -389,7 +293,7 @@ class TensorProcessor(BaseProcessor): "dims": value.dims, "data_type": anf_ir_pb2.DataType.Name(value.data_type), "histogram_buckets": buckets, - "statistics": get_statistics_dict(value.stats) + "statistics": TensorUtils.get_statistics_dict(value.stats) } }) diff --git a/mindinsight/datavisual/utils/tools.py b/mindinsight/datavisual/utils/tools.py index 05b4b4ea..d49f7c34 100644 --- a/mindinsight/datavisual/utils/tools.py +++ b/mindinsight/datavisual/utils/tools.py @@ -80,6 +80,25 @@ def to_int(param, param_name): return param +def to_float(param, param_name): + """ + Transfer param to float type. + + Args: + param (Any): A param transformed. + param_name (str): Param name. + + Returns: + float, value after transformed. + + """ + try: + param = float(param) + except ValueError: + raise exceptions.ParamTypeError(param_name, 'Float') + return param + + def str_to_bool(param, param_name): """ Check param and transform it to bool. 
diff --git a/mindinsight/debugger/__init__.py b/mindinsight/debugger/__init__.py new file mode 100644 index 00000000..637a3ff8 --- /dev/null +++ b/mindinsight/debugger/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Debugger Module Introduction. + +This module provides Python APIs to retrieve the debugger info and control the training process. +The APIs can help users to understand the training process and find the bugs in training script. +""" diff --git a/mindinsight/debugger/common/__init__.py b/mindinsight/debugger/common/__init__.py new file mode 100644 index 00000000..e3077430 --- /dev/null +++ b/mindinsight/debugger/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/mindinsight/debugger/common/exceptions/__init__.py b/mindinsight/debugger/common/exceptions/__init__.py new file mode 100644 index 00000000..e3077430 --- /dev/null +++ b/mindinsight/debugger/common/exceptions/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/mindinsight/debugger/common/exceptions/error_code.py b/mindinsight/debugger/common/exceptions/error_code.py new file mode 100644 index 00000000..52cf88ce --- /dev/null +++ b/mindinsight/debugger/common/exceptions/error_code.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Debugger error code and messages.""" +from enum import Enum, unique +from mindinsight.utils.constant import DebuggerErrors as DebuggerErrorCodes + + +_PARAM_ERROR_MASK = 0b00001 << 7 +_DEBUGGER_GRAPH_ERROR = 0b00010 << 7 +_DEBUGGER_RUNNING_ERROR = 0b00011 << 7 + + +@unique +class DebuggerErrors(DebuggerErrorCodes): + """Debugger error codes.""" + PARAM_TYPE_ERROR = 0 | _PARAM_ERROR_MASK + PARAM_VALUE_ERROR = 1 | _PARAM_ERROR_MASK + + NODE_NOT_IN_GRAPH_ERROR = 0 | _DEBUGGER_GRAPH_ERROR + GRAPH_NOT_EXIST_ERROR = 1 | _DEBUGGER_GRAPH_ERROR + + CREATE_WATCHPOINT_ERROR = 0 | _DEBUGGER_RUNNING_ERROR + UPDATE_WATCHPOINT_ERROR = 1 | _DEBUGGER_RUNNING_ERROR + DELETE_WATCHPOINT_ERROR = 2 | _DEBUGGER_RUNNING_ERROR + CONTINUE_ERROR = 3 | _DEBUGGER_RUNNING_ERROR + PAUSE_ERROR = 4 | _DEBUGGER_RUNNING_ERROR + COMPARE_TENSOR_ERROR = 5 | _DEBUGGER_RUNNING_ERROR + + +@unique +class DebuggerErrorMsg(Enum): + """Debugger error messages.""" + PARAM_TYPE_ERROR = "TypeError. {}" + PARAM_VALUE_ERROR = "ValueError. {}" + PARAM_MISSING_ERROR = "MissingError. {}" + UNEXPECTED_EXCEPTION_ERROR = "Unexpected exception. {}" + + GRAPH_NOT_EXIST_ERROR = "The graph does not exist." + + CREATE_WATCHPOINT_ERROR = "Create watchpoint failed. {}" + UPDATE_WATCHPOINT_ERROR = "Update watchpoint failed. {}" + DELETE_WATCHPOINT_ERROR = "Delete watchpoint failed. {}" + CONTINUE_ERROR = "Continue debugging failed. {}" + PAUSE_ERROR = "Pause debugging failed. {}" + COMPARE_TENSOR_ERROR = "Compare tensors failed. {}" diff --git a/mindinsight/debugger/common/exceptions/exceptions.py b/mindinsight/debugger/common/exceptions/exceptions.py new file mode 100644 index 00000000..81489a55 --- /dev/null +++ b/mindinsight/debugger/common/exceptions/exceptions.py @@ -0,0 +1,117 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +"""Definition of error code and relative messages in debugger module.""" +from mindinsight.utils.exceptions import MindInsightException +from mindinsight.debugger.common.exceptions.error_code import DebuggerErrors, DebuggerErrorMsg + + +class DebuggerParamTypeError(MindInsightException): + """The parameter type error in debugger module.""" + + def __init__(self, msg): + super(DebuggerParamTypeError, self).__init__( + error=DebuggerErrors.PARAM_TYPE_ERROR, + message=DebuggerErrorMsg.PARAM_TYPE_ERROR.value.format(msg) + ) + + +class DebuggerParamValueError(MindInsightException): + """The parameter value error in debugger module.""" + + def __init__(self, msg): + super(DebuggerParamValueError, self).__init__( + error=DebuggerErrors.PARAM_VALUE_ERROR, + message=DebuggerErrorMsg.PARAM_VALUE_ERROR.value.format(msg) + ) + + +class DebuggerCreateWatchPointError(MindInsightException): + """The error about creating watch point.""" + + def __init__(self, msg): + super(DebuggerCreateWatchPointError, self).__init__( + error=DebuggerErrors.CREATE_WATCHPOINT_ERROR, + message=DebuggerErrorMsg.CREATE_WATCHPOINT_ERROR.value.format(msg) + ) + + +class DebuggerUpdateWatchPointError(MindInsightException): + """The error about updating watch point.""" + + def __init__(self, msg): + super(DebuggerUpdateWatchPointError, self).__init__( + error=DebuggerErrors.UPDATE_WATCHPOINT_ERROR, + message=DebuggerErrorMsg.UPDATE_WATCHPOINT_ERROR.value.format(msg) + ) + + +class DebuggerDeleteWatchPointError(MindInsightException): + """The error about deleting watch point.""" + + def __init__(self, msg): + super(DebuggerDeleteWatchPointError, self).__init__( + error=DebuggerErrors.DELETE_WATCHPOINT_ERROR, + message=DebuggerErrorMsg.DELETE_WATCHPOINT_ERROR.value.format(msg) + ) + + +class DebuggerCompareTensorError(MindInsightException): + """The error about comparing tensors.""" + + def __init__(self, msg): + super(DebuggerCompareTensorError, self).__init__( + error=DebuggerErrors.COMPARE_TENSOR_ERROR, + message=DebuggerErrorMsg.COMPARE_TENSOR_ERROR.value.format(msg) + ) + + +class DebuggerContinueError(MindInsightException): + """The error about continuing debugging.""" + def __init__(self, msg): + super(DebuggerContinueError, self).__init__( + error=DebuggerErrors.CONTINUE_ERROR, + message=DebuggerErrorMsg.CONTINUE_ERROR.value.format(msg) + ) + + +class DebuggerPauseError(MindInsightException): + """The error about pausing debugging.""" + def __init__(self, msg): + super(DebuggerPauseError, self).__init__( + error=DebuggerErrors.PAUSE_ERROR, + message=DebuggerErrorMsg.PAUSE_ERROR.value.format(msg) + ) + + +class DebuggerNodeNotInGraphError(MindInsightException): + """The node is not in the graph.""" + def __init__(self, node_name, node_type=None): + if node_type is not None: + err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}, type: {node_type}." + else: + err_msg = f"Cannot find the node in graph by the given name. node name: {node_name}." 
+ super(DebuggerNodeNotInGraphError, self).__init__( + error=DebuggerErrors.NODE_NOT_IN_GRAPH_ERROR, + message=err_msg + ) + + +class DebuggerGraphNotExistError(MindInsightException): + """The graph does not exist.""" + def __init__(self): + super(DebuggerGraphNotExistError, self).__init__( + error=DebuggerErrors.GRAPH_NOT_EXIST_ERROR, + message=DebuggerErrorMsg.GRAPH_NOT_EXIST_ERROR.value + ) diff --git a/mindinsight/debugger/common/log.py b/mindinsight/debugger/common/log.py new file mode 100644 index 00000000..ccbef0eb --- /dev/null +++ b/mindinsight/debugger/common/log.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Import mindinsight unified log module.""" +from mindinsight.utils.log import setup_logger + +LOG_NAME = "debugger" +LOG_MODULE = "debugger" +logger = setup_logger(sub_module=LOG_MODULE, log_name=LOG_NAME) diff --git a/mindinsight/debugger/common/utils.py b/mindinsight/debugger/common/utils.py new file mode 100644 index 00000000..b8a1bac7 --- /dev/null +++ b/mindinsight/debugger/common/utils.py @@ -0,0 +1,168 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define the utils.""" +import enum +from collections import namedtuple + +import numpy as np + +from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError +from mindinsight.debugger.common.log import logger as log +from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply +from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum + +# translate the MindSpore type to numpy type. 
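+# Note: np.bool and np.str are aliases of the Python built-ins here; NumPy deprecates these aliases from version 1.20 onward.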
+NUMPY_TYPE_MAP = { + 'DT_BOOL': np.bool, + + 'DT_INT8': np.int8, + 'DT_INT16': np.int16, + 'DT_INT32': np.int32, + 'DT_INT64': np.int64, + + 'DT_UINT8': np.uint8, + 'DT_UINT16': np.uint16, + 'DT_UINT32': np.uint32, + 'DT_UINT64': np.uint64, + + 'DT_FLOAT16': np.float16, + 'DT_FLOAT32': np.float32, + 'DT_FLOAT64': np.float64, + + 'DT_STRING': np.str +} + + +@enum.unique +class ReplyStates(enum.Enum): + """Define the status of reply.""" + SUCCESS = 0 + FAILED = -1 + + +@enum.unique +class ServerStatus(enum.Enum): + """The status of debugger server.""" + PENDING = 'pending' # no client session has been connected + RECEIVE_GRAPH = 'receive graph' # the client session has sent the graph + WAITING = 'waiting' # the client session is ready + RUNNING = 'running' # the client session is running a script + + +@enum.unique +class Streams(enum.Enum): + """Define the streams to be dealt with.""" + + COMMAND = "command" + DATA = "data" + METADATA = "metadata" + GRAPH = 'node' + TENSOR = 'tensor' + WATCHPOINT = 'watchpoint' + WATCHPOINT_HIT = 'watchpoint_hit' + + +NodeBasicInfo = namedtuple('node_basic_info', ['name', 'full_name', 'type']) + + +def get_ack_reply(state=0): + """Get the ack EventReply.""" + reply = EventReply() + state_mapping = { + 0: EventReply.Status.OK, + 1: EventReply.Status.FAILED, + 2: EventReply.Status.PENDING + } + reply.status = state_mapping[state] + + return reply + + +def wrap_reply_response(error_code=None, error_message=None): + """ + Wrap reply response. + + Args: + error_code (str): Error code. Default: None. + error_message (str): Error message. Default: None. + + Returns: + dict, the reply body. + """ + if error_code is None: + reply = {'state': ReplyStates.SUCCESS.value} + else: + reply = { + 'state': ReplyStates.FAILED.value, + 'error_code': error_code, + 'error_message': error_message + } + + return reply + + +def create_view_event_from_tensor_history(tensor_history): + """ + Create view event reply according to tensor names. + + Args: + tensor_history (list[dict]): The list of tensor history. Each element has keys: + `name`, `node_type`. + + Returns: + EventReply, the event reply with view cmd. + """ + view_event = get_ack_reply() + for tensor_info in tensor_history: + node_type = tensor_info.get('node_type') + if node_type == NodeTypeEnum.CONST.value: + continue + truncate_tag = tensor_info.get('node_type') == NodeTypeEnum.PARAMETER.value + tensor_name = tensor_info.get('full_name', '') + # create view command + ms_tensor = view_event.view_cmd.tensors.add() + ms_tensor.node_name, ms_tensor.slot = tensor_name.rsplit(':', 1) + ms_tensor.truncate = truncate_tag + ms_tensor.iter = 'prev' if tensor_info.get('iter') else '' + + return view_event + + +def is_scope_type(node_type): + """Judge whether the type is scope type.""" + scope_types = [NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.AGGREGATION_SCOPE.value] + return node_type in scope_types + + +def str_to_slice_or_int(input_str): + """ + Translate param from string to slice or int. + + Args: + input_str (str): The string to be translated. + + Returns: + Union[int, slice], the transformed param.
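+ + Examples: + >>> str_to_slice_or_int('0:10:2') + slice(0, 10, 2) + >>> str_to_slice_or_int('5') + 5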
+ """ + try: + if ':' in input_str: + ret = slice(*map(lambda x: int(x.strip()) if x.strip() else None, input_str.split(':'))) + else: + ret = int(input_str) + except ValueError as err: + log.error("Failed to create slice from %s", input_str) + log.exception(err) + raise DebuggerParamValueError("Invalid shape.") + return ret diff --git a/mindinsight/debugger/debugger_cache.py b/mindinsight/debugger/debugger_cache.py new file mode 100644 index 00000000..dd56fd52 --- /dev/null +++ b/mindinsight/debugger/debugger_cache.py @@ -0,0 +1,154 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Implement the debugger data cache manager.""" +import sys + +from mindinsight.debugger.common.log import logger as log +from mindinsight.debugger.common.utils import Streams +from mindinsight.debugger.stream_handler import EventHandler, MetadataHandler, GraphHandler, \ + TensorHandler, WatchpointHandler, WatchpointHitHandler + +STREAM_HANDLER_MAP = { + Streams.COMMAND.value: EventHandler, + Streams.DATA.value: EventHandler, + Streams.METADATA.value: MetadataHandler, + Streams.GRAPH.value: GraphHandler, + Streams.TENSOR.value: TensorHandler, + Streams.WATCHPOINT.value: WatchpointHandler, + Streams.WATCHPOINT_HIT.value: WatchpointHitHandler +} + + +class DebuggerCache: + """The debugger data cache manager.""" + + def __init__(self): + self._stream_handler = {} + + def initialize(self): + """Initialize the stream handlers.""" + self._stream_handler = {} + for stream in Streams: + mode = stream.value + stream_handler = STREAM_HANDLER_MAP.get(mode) + self._stream_handler[mode] = stream_handler() + + def clean(self): + """Clean cache for all stream.""" + for _, stream_handler in self._stream_handler.items(): + stream_handler.clean() + + def get_stream_handler(self, mode): + """ + Get the stream handler object. + + Args: + mode (Streams): The type of stream handler. + + Returns: + StreamHandlerBase, the stream handler object. + """ + return self._stream_handler.get(mode.value) + + def _get(self, mode, pos): + """ + Get updated data or command from cache. + + Args: + mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`. + pos (int): The index of info. + + Returns: + object, the pos-th message about `mode` type of info. + """ + stream_handler = self.get_stream_handler(mode) + + return stream_handler.get(pos) + + def _put(self, mode, value): + """ + Set updated data or command from cache. + + Args: + mode (Streams): The type of info. `Streams.DATA` or `Streams.COMMAND`. + value (object): The info to be record in cache. + """ + stream_handler = self.get_stream_handler(mode) + + return stream_handler.put(value) + + def get_command(self, pos): + """ + Get the pos-th command in command stream. + + Args: + pos (int): The index of command. + + Returns: + int, the position of next message. + EventReply, the command object. 
+ """ + content = self._get(Streams.COMMAND, pos) + next_pos = content.get('metadata').get('pos') + reply = content.get('cmd') + return next_pos, reply + + def put_command(self, cmd): + """ + Set command to command stream. + + Args: + cmd (EventReply): The command EventReply. + """ + log.debug("Set command %s", cmd) + return self._put(Streams.COMMAND, {'cmd': cmd}) + + def has_command(self, pos): + """Judge if the number of command is no less than `pos`.""" + event = self.get_stream_handler(Streams.COMMAND).has_pos(pos) + + return event + + def clean_command(self): + """Clean command queue.""" + self.get_stream_handler(Streams.COMMAND).clean() + log.debug("Clean command.") + + def clean_data(self): + """Clean command queue.""" + self.get_stream_handler(Streams.DATA).clean() + log.debug("Clean data queue.") + + def get_data(self, pos): + """ + Get updated data from data stream. + + Args: + pos (int): The index of data. + + Returns: + object, updated data_value. + """ + return self._get(Streams.DATA, pos) + + def put_data(self, value): + """ + Set updated data to data stream. + + Args: + value (dict): The updated data. + """ + log.debug("Set <%d> bytes data", sys.getsizeof(value)) + return self._put(Streams.DATA, value) diff --git a/mindinsight/debugger/debugger_grpc_server.py b/mindinsight/debugger/debugger_grpc_server.py new file mode 100644 index 00000000..c70d8228 --- /dev/null +++ b/mindinsight/debugger/debugger_grpc_server.py @@ -0,0 +1,309 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Implement the debugger grpc server.""" +from functools import wraps + +from mindinsight.debugger.common.log import logger as log +from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ + create_view_event_from_tensor_history, Streams +from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base +from mindinsight.debugger.proto.ms_graph_pb2 import GraphProto + + +def debugger_wrap(func): + """Wrapper for catch exception.""" + + @wraps(func) + def record_log(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as err: + log.exception(err) + raise err + + return record_log + + +class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): + """The grpc server used to interactive with grpc client.""" + + def __init__(self, cache_store): + """ + Initialize. + + Args: + cache_store (DebuggerCache): Debugger cache store. 
+ """ + cache_store.initialize() + self._cache_store = cache_store + self._pos = None + self._status = None + self._view_event = None + self._view_round = None + self._continue_steps = None + self.init() + + def init(self): + """Init debugger grpc server.""" + self._pos = '0' + self._status = ServerStatus.PENDING + self._view_event = None + self._view_round = True + self._continue_steps = 0 + self._cache_store.clean() + + @debugger_wrap + def WaitCMD(self, request, context): + """Wait for a command in DebuggerCache.""" + # check if graph have already received. + log.info("Received WaitCMD at %s-th step.", request.cur_step) + if self._status == ServerStatus.PENDING: + log.warning("No graph received before WaitCMD.") + reply = get_ack_reply(1) + return reply + # send graph if has not been sent before + self._pre_process(request) + # deal with old command + reply = self._deal_with_old_command() + if reply: + log.info("Reply to WaitCMD with old command: %s", reply) + return reply + # send view cmd + if self._view_round and self._view_event: + self._view_round = False + reply = self._view_event + log.debug("Send ViewCMD.") + # continue multiple steps training + elif self._continue_steps != 0: + reply = get_ack_reply() + reply.run_cmd.run_steps = 1 + reply.run_cmd.run_level = 'step' + self._continue_steps = self._continue_steps - 1 if self._continue_steps > 0 else -1 + self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() + log.debug("Send RunCMD. Clean watchpoint hit.") + # wait for command + else: + reply = self._wait_for_next_command() + + if reply is None: + reply = get_ack_reply(1) + log.warning("Failed to get command event.") + else: + log.info("Reply to WaitCMD: %s", reply) + return reply + + def _pre_process(self, request): + """Send graph and metadata when WaitCMD first called.""" + metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) + if self._status == ServerStatus.RECEIVE_GRAPH: + self._status = ServerStatus.WAITING + metadata_stream.state = 'waiting' + metadata = metadata_stream.get() + self._cache_store.clean_command() + res = self._cache_store.get_stream_handler(Streams.GRAPH).get() + res.update(metadata) + self._cache_store.put_data(res) + log.info("Put graph into data queue.") + + if metadata_stream.step < request.cur_step or metadata_stream.full_name != request.cur_node: + # clean tensor cache and DataQueue at the beginning of each step + self._update_metadata(metadata_stream, request) + + def _update_metadata(self, metadata_stream, metadata_proto): + """Update metadata.""" + # reset view round and clean cache data + self._view_round = True + if metadata_stream.step < metadata_proto.cur_step: + self._cache_store.clean_data() + self._cache_store.get_stream_handler(Streams.TENSOR).clean_tensors( + metadata_proto.cur_step) + # put new metadata into cache + metadata_stream.put(metadata_proto) + cur_node = self._cache_store.get_stream_handler(Streams.GRAPH).get_node_name_by_full_name( + metadata_proto.cur_node) if metadata_proto.cur_node else '' + metadata_stream.node_name = cur_node + metadata = metadata_stream.get() + self._cache_store.put_data(metadata) + log.info("Put new metadata into data queue.") + + def _deal_with_old_command(self): + """Deal with old command.""" + event = None + while self._cache_store.has_command(self._pos) and event is None: + event = self._get_next_command() + log.debug("Deal with old %s-th command:\n%s.", self._pos, event) + + return event + + def _wait_for_next_command(self): + """ + Wait for next command. 
+ + Returns: + EventReply, the command event. + """ + log.info("Start to wait for command.") + self._cache_store.get_stream_handler(Streams.METADATA).state = 'waiting' + self._cache_store.put_data({'metadata': {'state': 'waiting'}}) + event = None + while event is None and self._status == ServerStatus.WAITING: + log.debug("Wait for %s-th command", self._pos) + event = self._get_next_command() + return event + + def _get_next_command(self): + """Get next command.""" + self._pos, event = self._cache_store.get_command(self._pos) + log.debug("Received event :%s", event) + if event is None: + return event + if isinstance(event, dict) and event.get('reset'): + self._set_view_event(event) + event = None + elif event.HasField('run_cmd'): + event = self._deal_with_run_cmd(event) + elif event.HasField('view_cmd'): + self._view_round = False + elif event.HasField('exit'): + self._cache_store.clean() + log.info("Clean cache for exit cmd.") + + return event + + def _deal_with_run_cmd(self, event): + """Deal with run cmd.""" + run_cmd = event.run_cmd + # receive step command + if run_cmd.run_level == 'step': + # receive pause cmd + if run_cmd.run_steps == 0: + log.debug("Pause training and wait for next command.") + self._continue_steps = 0 + return None + # receive step cmd + self._continue_steps = run_cmd.run_steps - 1 + event.run_cmd.run_steps = 1 + self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).clean() + log.debug("Receive RunCMD. Clean watchpoint hit cache.") + + return event + + def _set_view_event(self, event): + """Create view event for view cmd.""" + # the first tensor in view cmd is always the output + node_name = event.get('node_name') + tensor_history = event.get('tensor_history') + if not node_name or not tensor_history: + self._view_event = None + log.info("Reset view command to None.") + else: + # create view event and set + self._view_event = create_view_event_from_tensor_history(tensor_history) + log.info("Reset view command to %s.", node_name) + + @debugger_wrap + def SendMetadata(self, request, context): + """Send metadata into DebuggerCache.""" + log.info("Received Metadata.") + if self._status != ServerStatus.PENDING: + log.info("Re-initialize cache store when new session comes.") + self.init() + + client_ip = context.peer().split(':', 1)[-1] + metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) + metadata_stream.put(request) + metadata_stream.client_ip = client_ip + metadata = metadata_stream.get() + # put metadata into data queue + self._cache_store.put_data(metadata) + log.info("Put new metadata to DataQueue.") + reply = get_ack_reply() + log.info("Send the reply to %s.", client_ip) + return reply + + @debugger_wrap + def SendGraph(self, request_iterator, context): + """Send graph into DebuggerCache.""" + log.info("Received graph.") + serial_graph = b"" + for chunk in request_iterator: + serial_graph += chunk.buffer + graph = GraphProto.FromString(serial_graph) + log.debug("Deserialize the graph. 
Receive %s nodes", len(graph.node)) + self._cache_store.get_stream_handler(Streams.GRAPH).put(graph) + self._cache_store.get_stream_handler(Streams.TENSOR).put_const_vals(graph.const_vals) + self._status = ServerStatus.RECEIVE_GRAPH + reply = get_ack_reply() + log.info("Send the reply for graph.") + return reply + + @debugger_wrap + def SendTensors(self, request_iterator, context): + """Send tensors into DebuggerCache.""" + log.info("Received tensor.") + tensor_construct = [] + tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) + metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) + tensor_names = [] + step = metadata_stream.step + for tensor in request_iterator: + tensor_construct.append(tensor) + if tensor.finished: + tensor_stream.put({'step': step, 'tensor_protos': tensor_construct}) + tensor_construct = [] + tensor_names.append(':'.join([tensor.node_name, tensor.slot])) + continue + # send back tensor finished flag when all waiting tensor has value. + tensor_history = tensor_stream.get_tensor_history(tensor_names) + self._add_node_name_for_tensor_history(tensor_history) + metadata = metadata_stream.get() + tensor_history.update(metadata) + self._cache_store.put_data({}) # reply to the listening request + self._cache_store.put_data(tensor_history) + log.info("Send updated tensor history to data queue.") + reply = get_ack_reply() + return reply + + def _add_node_name_for_tensor_history(self, tensor_history): + """Add node name for tensor history.""" + graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) + for tensor_info in tensor_history.get('tensor_history'): + if tensor_info: + full_name, slot = tensor_info.get('full_name', '').rsplit(':', 1) + node_name = graph_stream.get_node_name_by_full_name(full_name) + tensor_info['name'] = node_name + ':' + slot + + @debugger_wrap + def SendWatchpointHits(self, request_iterator, context): + """Send watchpoint hits info DebuggerCache.""" + log.info("Received WatchpointHits. Left steps %d change to 0.", self._continue_steps) + self._continue_steps = 0 + self._view_event = None + watchpoint_hit_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) + watchpoint_stream = self._cache_store.get_stream_handler(Streams.WATCHPOINT) + graph_stream = self._cache_store.get_stream_handler(Streams.GRAPH) + for watchpoint_hit_proto in request_iterator: + watchpoint_hit = { + 'tensor_proto': watchpoint_hit_proto.tensor, + 'watchpoint': watchpoint_stream.get_watchpoint_by_id(watchpoint_hit_proto.id), + 'node_name': graph_stream.get_node_name_by_full_name( + watchpoint_hit_proto.tensor.node_name) + } + watchpoint_hit_stream.put(watchpoint_hit) + watchpoint_hits_info = watchpoint_hit_stream.get() + self._cache_store.put_data(watchpoint_hits_info) + log.info("Send the watchpoint hits to DataQueue.\nSend the reply.") + reply = get_ack_reply() + return reply diff --git a/mindinsight/debugger/debugger_server.py b/mindinsight/debugger/debugger_server.py new file mode 100644 index 00000000..e32fae86 --- /dev/null +++ b/mindinsight/debugger/debugger_server.py @@ -0,0 +1,752 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Implement the debugger server.""" +import signal +from concurrent import futures +from threading import Thread + +import grpc + +from mindinsight.conf import settings +from mindinsight.datavisual.data_transform.graph import NodeTypeEnum +from mindinsight.datavisual.utils.tools import to_float +from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ + DebuggerParamTypeError, DebuggerCreateWatchPointError, DebuggerUpdateWatchPointError, \ + DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError +from mindinsight.debugger.common.log import logger as log +from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \ + create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo, \ + str_to_slice_or_int +from mindinsight.debugger.debugger_cache import DebuggerCache +from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer +from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base +from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD +from mindinsight.utils.exceptions import MindInsightException + + +class DebuggerServer: + """The server manager of debugger.""" + + def __init__(self, grpc_port=None): + self.grpc_port = grpc_port + self.cache_store = DebuggerCache() + self.grpc_server = DebuggerGrpcServer(self.cache_store) + self.grpc_server_manager = None + self.back_server = None + self._watch_point_id = 0 + + def start(self): + """Start server.""" + grpc_port = self.grpc_port if self.grpc_port else "50051" + host = settings.HOST if hasattr(settings, 'HOST') else '[::]' + hostname = "{}:{}".format(host, grpc_port) + # initialize a grpc server + grpc_server_manager = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + grpc_server_base.add_EventListenerServicer_to_server(self.grpc_server, grpc_server_manager) + grpc_server_manager.add_insecure_port(hostname) + grpc_server_manager.start() + my_server_thread = Thread(target=grpc_server_manager.wait_for_termination) + # start grpc server + my_server_thread.start() + self.back_server = my_server_thread + self.grpc_server_manager = grpc_server_manager + # register stop server handler + signal.signal(signal.SIGINT, self._stop_handler) + log.info("Start grpc server %s", hostname) + + def _stop_handler(self, signum, frame): + """Handle the stop signal and stop the debugger server.""" + self.stop() + log.debug("Deal with stop signal: %s, %s", signum, frame) + + def stop(self): + """Stop debugger server.""" + self.grpc_server_manager.stop(grace=None) + self.back_server.join() + log.info("Stop debugger server.") + + def poll_data(self, pos): + """ + Get the pos-th data from DebuggerCache. + + Args: + pos (str): The index of data. + + Returns: + dict, the data to be updated. + """ + if not isinstance(pos, str): + log.error("Pos should be string. 
Received: %s", pos) + raise DebuggerParamValueError("Pos should be string.") + + reply = self.cache_store.get_data(pos) + + return reply + + def search(self, name, watch_point_id): + """Search for single node in graph.""" + log.info("receive search request for node:%s, in watchpoint:%d", name, watch_point_id) + graph = self.cache_store.get_stream_handler(Streams.GRAPH).search_nodes(name) + self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes( + graph, watch_point_id) + return graph + + def tensor_comparisons(self, name, shape, detail='data', tolerance='0'): + """ + Get tensor comparisons data for given name, detail, shape and tolerance. + + Args: + name (str): The name of tensor for ui. + detail (str): Specify which data to query. Current available value is 'data' which means + concrete tensor data. Histogram or unique count can be supported in the future. + shape (str): Specify concrete dimensions of shape. + tolerance (str): Specify tolerance of difference between current step tensor and previous + step tensor. Default value is 0. + + Raises: + DebuggerParamValueError, If node type is not parameter or value of detail is not support. + DebuggerCompareTensorError, If MindSpore is not in waiting state. + Returns: + dict, the retrieved data. + """ + if self.cache_store.get_stream_handler( + Streams.METADATA).state != ServerStatus.WAITING.value: + log.error("Failed to compare tensors as the MindSpore is not in waiting state.") + raise DebuggerCompareTensorError( + "Failed to compare tensors as the MindSpore is not in waiting state." + ) + self.validate_tensor_param(name, detail) + parsed_shape = self.parse_shape(shape) + node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) + tolerance = to_float(tolerance, 'tolerance') + tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) + if detail == 'data': + if node_type == NodeTypeEnum.PARAMETER.value: + reply = tensor_stream.get_tensors_diff(tensor_name, parsed_shape, tolerance) + else: + raise DebuggerParamValueError("The node type must be parameter, but got {}.".format(node_type)) + else: + raise DebuggerParamValueError("The value of detail: {} is not support.".format(detail)) + return reply + + def retrieve(self, mode, filter_condition=None): + """ + Retrieve data according to mode and params. + + Args: + mode (str): The type of info message. + filter_condition (dict): The filter condition. + + Returns: + dict, the retrieved data. + """ + log.info("receive retrieve request for mode:%s\n, filter_condition: %s", mode, + filter_condition) + # validate watchpoint_id + + mode_mapping = { + 'all': self._retrieve_all, + 'node': self._retrieve_node, + 'watchpoint': self._retrieve_watchpoint, + 'watchpoint_hit': self._retrieve_watchpoint_hit + } + # validate param + if mode not in mode_mapping.keys(): + log.error("Invalid param . 
should be in ['all', 'node', 'watchpoint', " + "'watchpoint_hit', 'tensor'], but got %s.", mode_mapping) + raise DebuggerParamTypeError("Invalid mode.") + filter_condition = {} if filter_condition is None else filter_condition + self._watch_point_id = filter_condition.get('watch_point_id', 0) + reply = mode_mapping[mode](filter_condition) + + return reply + + def _retrieve_all(self, filter_condition=None): + """Retrieve metadata, root graph and watchpoint list.""" + if filter_condition: + log.error("No filter condition required for retrieve all request.") + raise DebuggerParamTypeError("filter_condition should be empty.") + result = {} + self.cache_store.clean_data() + log.info("Clean data queue cache when retrieve all request.") + self.cache_store.put_command({'reset': True}) + for stream in [Streams.METADATA, Streams.GRAPH, Streams.WATCHPOINT]: + sub_res = self.cache_store.get_stream_handler(stream).get() + result.update(sub_res) + + return result + + def _retrieve_node(self, filter_condition): + """ + Retrieve node info. + + Args: + filter_condition (dict): Filter condition. + + - name (str): The name of single node. + + - watch_point_id (int): The id of watchpoint. + + - single_node (bool): If False, return the sub-layer of single node. If True, return + the node list from root node to single node. + + Returns: + dict, the node info. + """ + log.info("Retrieve node %s.", filter_condition) + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + # validate parameters + node_name = filter_condition.get('name') + if not node_name: + node_type = NodeTypeEnum.NAME_SCOPE.value + else: + node_type = graph_stream.get_node_type(node_name) + filter_condition['node_type'] = node_type + filter_condition['single_node'] = bool(filter_condition.get('single_node')) + # get graph for scope node + if is_scope_type(node_type): + reply = self._get_nodes_info(filter_condition) + # get tensor history for leaf node + else: + reply = self._get_tensor_history(node_name) + if filter_condition.get('single_node'): + graph = self._get_nodes_info(filter_condition) + reply.update(graph) + return reply + + def _get_nodes_info(self, filter_condition): + """ + Get nodes info. + + Args: + filter_condition (dict): The filter condition. + + - name (str): The node name. + + - single_node (bool): If False, return the sub-layer of single node. If True, return + the node list from root node to single node. + + - watch_point_id (int): The id of watchpoint. + + Returns: + dict, reply with graph. + """ + # get graph + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + reply = graph_stream.get(filter_condition) + graph = reply.get('graph') + # add watched label + self.cache_store.get_stream_handler(Streams.WATCHPOINT).set_watch_nodes( + graph, self._watch_point_id) + return reply + + def retrieve_tensor_history(self, node_name): + """ + Retrieve tensor history for leaf node. + + Args: + node_name (str): The name of leaf node. + + Returns: + dict, the tensor history and metadata. 
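+
+        Examples:
+            A minimal usage sketch; `server` and the node name below are illustrative:
+
+                >>> history = server.retrieve_tensor_history('Default/network/Conv2d-op1')
+                >>> # `history` carries the 'tensor_history' list merged with
+                >>> # metadata such as the current state and step.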
+ """ + log.info("Retrieve tensor history for node: %s.", node_name) + self._validate_leaf_name(node_name) + res = self._get_tensor_history(node_name) + return res + + def _validate_leaf_name(self, node_name): + """Validate if the node is a leaf node.""" + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + node_type = graph_stream.get_node_type(node_name) + if is_scope_type(node_type): + log.error("Scope type node has no tensor history.") + raise DebuggerParamValueError("Invalid leaf node name.") + + def _get_tensor_history(self, node_name): + """ + Get tensor history for single node. + + Args: + node_name (str): The name of leaf node. + + Returns: + dict, the tensor history and metadata. + """ + # get basic tensor history + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + tensor_history = graph_stream.get_tensor_history(node_name) + # set the view event + self.cache_store.put_command( + {'reset': True, + 'node_name': node_name, + 'tensor_history': tensor_history.get('tensor_history')}) + # add tensor value for tensor history + self._add_tensor_value_for_tensor_history(tensor_history) + # add hit label for tensor history + watchpoint_hit_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT) + watchpoint_hit_stream.update_tensor_history(tensor_history) + # add metadata + metadata = self.cache_store.get_stream_handler(Streams.METADATA).get() + tensor_history.update(metadata) + return tensor_history + + def _add_tensor_value_for_tensor_history(self, tensor_history): + """ + Add tensor value for_tensor_history and send ViewCMD if tensor value missed. + + Args: + tensor_history (list[dict]): A list of tensor info, including name and type. + + Returns: + dict, the tensor info. + """ + tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR) + missed_tensors = tensor_stream.update_tensor_history(tensor_history) + if missed_tensors: + view_cmd = create_view_event_from_tensor_history(missed_tensors) + self.cache_store.put_command(view_cmd) + log.debug("Send view cmd.") + + def retrieve_tensor_value(self, name, detail, shape): + """Retrieve the tensor value.""" + log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape) + self.validate_tensor_param(name, detail) + parsed_shape = self.parse_shape(shape) + node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name) + reply = self.cache_store.get_stream_handler(Streams.TENSOR).get( + {'name': tensor_name, + 'node_type': node_type, + 'shape': parsed_shape} + ) + reply['tensor_value']['name'] = name + + return reply + + def _get_tensor_name_and_type_by_ui_name(self, name): + """ + Get inner tensor name and type by UI name. + + Args: + name (str): Node name shown in UI. + + Returns: + str, full name of tensor. + str, node type of tensor. + """ + node_name, slot = name.rsplit(':', 1) + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + node_type = graph_stream.get_node_type(node_name) + full_name = graph_stream.get_full_name(node_name) + tensor_name = full_name + ':' + slot + return node_type, tensor_name + + @staticmethod + def validate_tensor_param(name, detail): + """Validate params for retrieve tensor request.""" + # validate name + if not isinstance(name, str) or ':' not in name: + log.error("Invalid tensor name. Received: %s", name) + raise DebuggerParamValueError("Invalid tensor name.") + # validate data + if detail != 'data': + log.error("Invalid detail value. 
Received: %s", detail) + raise DebuggerParamValueError("Invalid detail value.") + + @staticmethod + def parse_shape(shape): + """Parse shape.""" + if shape is None: + return shape + if not (isinstance(shape, str) and shape.startswith('[') and shape.endswith(']')): + log.error("Invalid shape. Received: %s", shape) + raise DebuggerParamValueError("Invalid shape.") + shape = shape.strip('[]') + if shape.count(':') > 2: + log.error("Invalid shape. At most two dimensions are specified.") + raise DebuggerParamValueError("Invalid shape.") + parsed_shape = tuple( + str_to_slice_or_int(dim) for dim in shape.split(',')) if shape else tuple() + log.info("Parsed shape: %s from %s", parsed_shape, shape) + return parsed_shape + + def _retrieve_watchpoint(self, filter_condition): + """ + Retrieve watchpoint. + + Args: + filter_condition (dict): Filter condition. + + - watch_point_id (int): The id of watchoint. If not given, return all watchpoints. + + - name (str): The name of single node. + + - single_node (bool): If False, return the sub-layer of single node. If True, return + the node list from root node to single node. + + Returns: + dict, watch point list or relative graph. + """ + watchpoint_id = filter_condition.get('watch_point_id') + if watchpoint_id is None: + reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT).get() + log.debug("Get condition of watchpoints.") + else: + reply = self._retrieve_node(filter_condition) + log.debug("Get graph of %d-th watchpoint.", watchpoint_id) + + return reply + + def _retrieve_watchpoint_hit(self, filter_condition): + """ + Retrieve watchpoint hit. + + Args: + filter_condition (dict): Filter condition. + + - name (str): The name of single node. + + - single_node (bool): If False, return the sub-layer of single node. If True, return + the node list from root node to single node. + + Returns: + dict, watch point list or relative graph. + """ + node_name = filter_condition.get('name') + # get watchpoint hit list + if node_name is None: + reply = self.cache_store.get_stream_handler(Streams.WATCHPOINT_HIT).get() + return reply + + self._validate_leaf_name(node_name) + # get tensor history + reply = self._get_tensor_history(node_name) + log.debug("Get tensor history for watchpoint hit node.") + # get single graph + if filter_condition.get('single_node'): + graph = self._get_nodes_info(filter_condition) + reply.update(graph) + log.debug("Get tensor history for watchpoint hit node.") + + return reply + + def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None): + """ + Create watchpoint. + + Args: + watch_condition (dict): The watch condition. + + - condition (str): Accept `INF` or `NAN`. + + - param (list[float]): Not defined yet. + watch_nodes (list[str]): The list of node names. + watch_point_id (int): The id of watchpoint. + + Returns: + dict, the id of new watchpoint. + """ + log.info("Received create watchpoint request. WatchCondition: %s", watch_condition) + metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA) + if metadata_stream.state != ServerStatus.WAITING.value: + log.error("Failed to create watchpoint as the MindSpore is not in waiting state.") + raise DebuggerCreateWatchPointError( + "Failed to create watchpoint as the MindSpore is not in waiting state." 
+ ) + if metadata_stream.backend == 'GPU' and watch_condition.get('condition') == 'OVERFLOW': + log.error("GPU doesn't support OVERFLOW watch condition.") + raise DebuggerParamValueError("GPU doesn't support OVERFLOW watch condition.") + + watch_nodes = self._get_node_basic_infos(watch_nodes) + watch_point_id = self.cache_store.get_stream_handler(Streams.WATCHPOINT).create_watchpoint( + watch_condition, watch_nodes, watch_point_id) + log.info("Create watchpoint %d", watch_point_id) + return {'id': watch_point_id} + + def update_watchpoint(self, watch_point_id, watch_nodes, mode, name=None): + """ + Update watchpoint. + + Args: + watch_point_id (int): The id of watchpoint. + watch_nodes (list[str]): The list of node names. + mode (int): The update operator on nodes. 0 for remove nodes from watch nodes. + 1 for add nodes to watch nodes. + name (str): The search name. Default: None. + + Returns: + dict, empty response. + """ + if self.cache_store.get_stream_handler( + Streams.METADATA).state != ServerStatus.WAITING.value: + log.error("Failed to update watchpoint as the MindSpore is not in waiting state.") + raise DebuggerUpdateWatchPointError( + "Failed to update watchpoint as the MindSpore is not in waiting state." + ) + # validate + if not watch_nodes or not watch_point_id: + log.error("Invalid parameter for update watchpoint.") + raise DebuggerParamValueError("Invalid parameter for update watchpoint.") + # update watch node + if name is not None: + watch_nodes = self._get_watch_nodes_by_search(watch_nodes) + elif mode == 1: + watch_nodes = self._get_node_basic_infos(watch_nodes) + + self.cache_store.get_stream_handler(Streams.WATCHPOINT).update_watchpoint( + watch_point_id, watch_nodes, mode) + log.info("Update watchpoint with id: %d", watch_point_id) + return {} + + def _get_watch_nodes_by_search(self, watch_nodes): + """Get watched leaf nodes by search name.""" + watched_leaf_nodes = [] + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + for search_name in watch_nodes: + search_nodes = graph_stream.get_searched_node_list() + search_node_names = [ + NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type) + for node in search_nodes + if node.name.startswith(search_name)] + watched_leaf_nodes.extend(search_node_names) + + log.debug("Update nodes: %s", watched_leaf_nodes) + + return watched_leaf_nodes + + def delete_watchpoint(self, watch_point_id): + """ + Delete watchpoint. + + Args: + watch_point_id (int): The id of watchpoint. + + Returns: + dict, empty response. + """ + if self.cache_store.get_stream_handler( + Streams.METADATA).state != ServerStatus.WAITING.value: + log.error("Failed to delete watchpoint as the MindSpore is not in waiting state.") + raise DebuggerDeleteWatchPointError( + "Failed to delete watchpoint as the MindSpore is not in waiting state." 
+        )
+        self.cache_store.get_stream_handler(Streams.WATCHPOINT).delete_watchpoint(
+            watch_point_id)
+        log.info("Delete watchpoint with id: %d", watch_point_id)
+        return {}
+
+    def _get_node_basic_infos(self, node_names):
+        """Get node info according to node names."""
+        if not node_names:
+            return []
+        graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH)
+        node_infos = []
+        for node_name in node_names:
+            node_type = graph_stream.get_node_type(node_name)
+            # expand aggregation scope nodes into their sub nodes; optimize later
+            if node_type == NodeTypeEnum.AGGREGATION_SCOPE.value:
+                sub_nodes = graph_stream.get_nodes(node_name)
+                sub_infos = [NodeBasicInfo(name=node.name, full_name=node.full_name, type=node.type)
+                             for node in sub_nodes]
+                node_infos.extend(sub_infos)
+                continue
+            full_name = graph_stream.get_full_name(node_name)
+            node_infos.append(NodeBasicInfo(name=node_name, full_name=full_name, type=node_type))
+        return node_infos
+
+    def control(self, params=None):
+        """
+        Control the training process.
+
+        Args:
+            params (dict): The control params.
+
+                - mode (str): Acceptable control command, including `continue`,
+                    `pause` and `terminate`.
+
+                - level (str): The control granularity, `node` level or `step` level.
+                    Default: `step`.
+
+                - steps (int): Specify the steps that training should run.
+                    Used when `level` is `step`.
+
+                - name (str): Specify the name of the node. Used when `level` is `node`.
+
+        Returns:
+            dict, the response.
+        """
+        log.info("Receive control request: %s.", params)
+        mode = params.get('mode')
+        metadata_stream = self.cache_store.get_stream_handler(Streams.METADATA)
+        if mode == 'continue':
+            reply = self._continue(metadata_stream, params)
+        elif mode in ['pause', 'terminate']:
+            mode_mapping = {
+                'pause': self._pause,
+                'terminate': self._terminate
+            }
+            reply = mode_mapping.get(mode)(metadata_stream)
+        else:
+            log.error("Invalid control mode %s", mode)
+            raise DebuggerParamValueError("Invalid control mode.")
+
+        return reply
+
+    def _continue(self, metadata_stream, params):
+        """
+        Send RunCMD to MindSpore.
+
+        Args:
+            metadata_stream (MetadataHandler): The metadata handler.
+            params (dict): The control params.
+        """
+        if metadata_stream.state != ServerStatus.WAITING.value:
+            log.error("MindSpore is not ready to run. Current state is: %s", metadata_stream.state)
+            raise DebuggerContinueError(
+                "MindSpore is not ready to run or is running currently."
+            )
+        metadata_stream.state = ServerStatus.RUNNING.value
+        current_state = ServerStatus.RUNNING.value
+        try:
+            event = self._construct_run_event(params)
+            self._send_watchpoints()
+            self.cache_store.put_command(event)
+        except MindInsightException as err:
+            log.error("Failed to send run event.")
+            log.exception(err)
+            current_state = ServerStatus.WAITING.value
+            metadata_stream.state = current_state
+            raise DebuggerContinueError("Failed to send run command.")
+        else:
+            log.debug("Send the RunCMD to command queue.")
+
+        return {'metadata': {'state': current_state}}
+
+    def _validate_node_type(self, node_name):
+        """Check the node type in node control."""
+        if not node_name:
+            return
+        node_type = self.cache_store.get_stream_handler(Streams.GRAPH).get_node_type(node_name)
+        unsupported_types = [item.value for item in list(NodeTypeEnum)]
+        if node_type in unsupported_types:
+            log.error("Invalid node type. %s", node_name)
+            raise DebuggerParamValueError(f"The type of node {node_name} is unsupported for "
+                                          "the continue-to-node command.")
+
+    def _construct_run_event(self, params):
+        """
+        Construct run cmd from input control params.
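+
+        For example, a step-level request could look like
+        `{'level': 'step', 'steps': 10}`, while a node-level request could
+        look like `{'level': 'node', 'name': 'Default/network/Conv2d-op1'}`
+        (both payloads are illustrative sketches, not captured requests).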
+
+        Args:
+            params (dict): The control params.
+
+                - level (str): The control granularity, `node` level or `step` level.
+                    Default: `step`.
+
+                - steps (int): Specify the steps that training should run.
+                    Used when `level` is `step`.
+
+                - name (str): Specify the name of the node. Used when `level` is `node`.
+
+        Returns:
+            EventReply, control event with run command.
+        """
+        level = params.get('level', 'step')
+        event = get_ack_reply()
+        if level == 'step':
+            steps = params.get('steps')
+            if not steps:
+                steps = 1
+            run_cmd = RunCMD(run_level='step', run_steps=steps)
+        elif level == 'node':
+            self._validate_node_type(params.get('name'))
+            name = self.cache_store.get_stream_handler(Streams.GRAPH).get_full_name(
+                params['name'])
+            if not name:
+                name = ''
+            run_cmd = RunCMD(run_level='node', node_name=name)
+        else:
+            log.error("Invalid value. `level` should be `step` or `node`. Got %s", level)
+            raise DebuggerParamValueError("`level` should be `step` or `node`.")
+
+        event.run_cmd.CopyFrom(run_cmd)
+        log.debug("Construct run event. %s", event)
+        return event
+
+    def _send_watchpoints(self):
+        """Send watchpoints to MindSpore as SetCMD events."""
+        watchpoint_stream = self.cache_store.get_stream_handler(Streams.WATCHPOINT)
+        watchpoints = watchpoint_stream.get(filter_condition=True).get('watch_points')
+        if watchpoints:
+            for watchpoint in watchpoints:
+                event = get_ack_reply()
+                event.set_cmd.CopyFrom(watchpoint)
+                self.cache_store.put_command(event)
+            watchpoint_stream.sync_set_cmd()
+            log.debug("Send SetCMD to MindSpore. %s", event)
+
+    def _pause(self, metadata_stream):
+        """
+        Pause the training.
+
+        Args:
+            metadata_stream (MetadataHandler): The metadata stream handler.
+        """
+        if metadata_stream.state != ServerStatus.RUNNING.value:
+            log.error("The MindSpore is not running.")
+            raise DebuggerPauseError("The MindSpore is not running.")
+        metadata_stream.state = 'waiting'
+        event = get_ack_reply()
+        event.run_cmd.CopyFrom(RunCMD(run_level='step', run_steps=0))
+        self.cache_store.put_command(event)
+        log.debug("Send the pause command.")
+        return {'metadata': {'state': 'waiting'}}
+
+    def _terminate(self, metadata_stream):
+        """
+        Terminate the training.
+
+        Args:
+            metadata_stream (MetadataHandler): The metadata stream handler.
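+
+        Returns:
+            dict, metadata with the updated ('pending') state.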
+ """ + metadata_stream.state = 'pending' + event = get_ack_reply() + event.exit = True + self.cache_store.put_command(event) + log.debug("Send the ExitCMD.") + return {'metadata': {'state': 'pending'}} + + def retrieve_node_by_bfs(self, node_name, ascend=False): + """Get the graph and tensor history of the next node name according to node_name.""" + log.info("Retrieve node <%s> by bfs, `ascend` is :%s", + node_name, ascend) + reply = {} + graph_stream = self.cache_store.get_stream_handler(Streams.GRAPH) + next_node_name = graph_stream.get_node_by_bfs_order(node_name, ascend) + # no next node + if next_node_name is None: + return reply + # add graph and tensor history for next node + filter_condition = { + 'name': next_node_name, + 'single_node': True + } + search_graph = self._get_nodes_info(filter_condition) + tensor_history = self._get_tensor_history(next_node_name) + reply = {'name': next_node_name} + reply.update(search_graph) + reply.update(tensor_history) + + return reply diff --git a/mindinsight/debugger/proto/debug_grpc.proto b/mindinsight/debugger/proto/debug_grpc.proto new file mode 100644 index 00000000..562241d9 --- /dev/null +++ b/mindinsight/debugger/proto/debug_grpc.proto @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package debugger; + +import "mindinsight/debugger/proto/ms_graph.proto"; + + +service EventListener { + rpc WaitCMD (Metadata) returns (EventReply) {}; + rpc SendMetadata (Metadata) returns (EventReply) {}; + rpc SendGraph (stream Chunk) returns (EventReply) {}; + rpc SendTensors (stream TensorProto) returns (EventReply) {}; + rpc SendWatchpointHits (stream WatchpointHit) returns (EventReply) {}; +} + +message Metadata { + string device_name = 1; + int32 cur_step = 2; + // define the backend is 'GPU' or 'Ascend' + string backend = 3; + // the full name of current node + string cur_node = 4; + // check if training is done. + bool training_done = 5; +} + +message Chunk { + bytes buffer = 1; +} +message EventReply { + enum Status { + OK = 0; + FAILED = 1; + PENDING = 2; + } + + Status status = 1; + + oneof cmd { + bool exit = 2; + RunCMD run_cmd = 3; + SetCMD set_cmd = 4; + ViewCMD view_cmd = 5; + } +} + +message RunCMD { + // running level. 
'step' or 'node' + string run_level = 1; + + oneof cmd { + int32 run_steps = 2; + + // the full name of next node + string node_name = 3; + } +} + +message SetCMD { + repeated WatchNode watch_nodes = 1; + WatchCondition watch_condition = 2; + bool delete = 3; + int32 id = 4; +} + +message ViewCMD { + repeated TensorProto tensors = 1; +} + +message WatchCondition { + enum Condition { + nan = 0; + inf = 1; + overflow = 2; + max_gt = 3; + max_lt = 4; + min_gt = 5; + min_lt = 6; + max_min_gt = 7; + max_min_lt = 8; + mean_gt = 9; + mean_lt = 10; + } + Condition condition = 1; + float value = 2; // for between condition, there will be two values +} + +message WatchNode { + string node_name = 1; + string node_type = 2; +} + +message WatchpointHit { + TensorProto tensor = 1; + WatchCondition watch_condition = 2; + int32 id = 3; +} diff --git a/mindinsight/debugger/proto/debug_grpc_pb2.py b/mindinsight/debugger/proto/debug_grpc_pb2.py new file mode 100644 index 00000000..f75175a1 --- /dev/null +++ b/mindinsight/debugger/proto/debug_grpc_pb2.py @@ -0,0 +1,683 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: mindinsight/debugger/proto/debug_grpc.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2 + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='mindinsight/debugger/proto/debug_grpc.proto', + package='debugger', + syntax='proto3', + serialized_options=None, + serialized_pb=b'\n+mindinsight/debugger/proto/debug_grpc.proto\x12\x08\x64\x65\x62ugger\x1a)mindinsight/debugger/proto/ms_graph.proto\"k\n\x08Metadata\x12\x13\n\x0b\x64\x65vice_name\x18\x01 \x01(\t\x12\x10\n\x08\x63ur_step\x18\x02 \x01(\x05\x12\x0f\n\x07\x62\x61\x63kend\x18\x03 \x01(\t\x12\x10\n\x08\x63ur_node\x18\x04 \x01(\t\x12\x15\n\rtraining_done\x18\x05 \x01(\x08\"\x17\n\x05\x43hunk\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\"\xec\x01\n\nEventReply\x12+\n\x06status\x18\x01 \x01(\x0e\x32\x1b.debugger.EventReply.Status\x12\x0e\n\x04\x65xit\x18\x02 \x01(\x08H\x00\x12#\n\x07run_cmd\x18\x03 \x01(\x0b\x32\x10.debugger.RunCMDH\x00\x12#\n\x07set_cmd\x18\x04 \x01(\x0b\x32\x10.debugger.SetCMDH\x00\x12%\n\x08view_cmd\x18\x05 \x01(\x0b\x32\x11.debugger.ViewCMDH\x00\")\n\x06Status\x12\x06\n\x02OK\x10\x00\x12\n\n\x06\x46\x41ILED\x10\x01\x12\x0b\n\x07PENDING\x10\x02\x42\x05\n\x03\x63md\"L\n\x06RunCMD\x12\x11\n\trun_level\x18\x01 \x01(\t\x12\x13\n\trun_steps\x18\x02 \x01(\x05H\x00\x12\x13\n\tnode_name\x18\x03 \x01(\tH\x00\x42\x05\n\x03\x63md\"\x81\x01\n\x06SetCMD\x12(\n\x0bwatch_nodes\x18\x01 \x03(\x0b\x32\x13.debugger.WatchNode\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\x0e\n\x06\x64\x65lete\x18\x03 \x01(\x08\x12\n\n\x02id\x18\x04 \x01(\x05\"1\n\x07ViewCMD\x12&\n\x07tensors\x18\x01 \x03(\x0b\x32\x15.debugger.TensorProto\"\xee\x01\n\x0eWatchCondition\x12\x35\n\tcondition\x18\x01 \x01(\x0e\x32\".debugger.WatchCondition.Condition\x12\r\n\x05value\x18\x02 
\x01(\x02\"\x95\x01\n\tCondition\x12\x07\n\x03nan\x10\x00\x12\x07\n\x03inf\x10\x01\x12\x0c\n\x08overflow\x10\x02\x12\n\n\x06max_gt\x10\x03\x12\n\n\x06max_lt\x10\x04\x12\n\n\x06min_gt\x10\x05\x12\n\n\x06min_lt\x10\x06\x12\x0e\n\nmax_min_gt\x10\x07\x12\x0e\n\nmax_min_lt\x10\x08\x12\x0b\n\x07mean_gt\x10\t\x12\x0b\n\x07mean_lt\x10\n\"1\n\tWatchNode\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x11\n\tnode_type\x18\x02 \x01(\t\"u\n\rWatchpointHit\x12%\n\x06tensor\x18\x01 \x01(\x0b\x32\x15.debugger.TensorProto\x12\x31\n\x0fwatch_condition\x18\x02 \x01(\x0b\x32\x18.debugger.WatchCondition\x12\n\n\x02id\x18\x03 \x01(\x05\x32\xc3\x02\n\rEventListener\x12\x35\n\x07WaitCMD\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12:\n\x0cSendMetadata\x12\x12.debugger.Metadata\x1a\x14.debugger.EventReply\"\x00\x12\x36\n\tSendGraph\x12\x0f.debugger.Chunk\x1a\x14.debugger.EventReply\"\x00(\x01\x12>\n\x0bSendTensors\x12\x15.debugger.TensorProto\x1a\x14.debugger.EventReply\"\x00(\x01\x12G\n\x12SendWatchpointHits\x12\x17.debugger.WatchpointHit\x1a\x14.debugger.EventReply\"\x00(\x01\x62\x06proto3' + , + dependencies=[mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.DESCRIPTOR,]) + + + +_EVENTREPLY_STATUS = _descriptor.EnumDescriptor( + name='Status', + full_name='debugger.EventReply.Status', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='OK', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FAILED', index=1, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PENDING', index=2, number=2, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=423, + serialized_end=464, +) +_sym_db.RegisterEnumDescriptor(_EVENTREPLY_STATUS) + +_WATCHCONDITION_CONDITION = _descriptor.EnumDescriptor( + name='Condition', + full_name='debugger.WatchCondition.Condition', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='nan', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='inf', index=1, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='overflow', index=2, number=2, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='max_gt', index=3, number=3, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='max_lt', index=4, number=4, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='min_gt', index=5, number=5, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='min_lt', index=6, number=6, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='max_min_gt', index=7, number=7, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='max_min_lt', index=8, number=8, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='mean_gt', index=9, number=9, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='mean_lt', index=10, number=10, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=824, + serialized_end=973, +) +_sym_db.RegisterEnumDescriptor(_WATCHCONDITION_CONDITION) + + +_METADATA = _descriptor.Descriptor( + name='Metadata', + full_name='debugger.Metadata', + filename=None, + file=DESCRIPTOR, + 
containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='device_name', full_name='debugger.Metadata.device_name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='cur_step', full_name='debugger.Metadata.cur_step', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='backend', full_name='debugger.Metadata.backend', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='cur_node', full_name='debugger.Metadata.cur_node', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='training_done', full_name='debugger.Metadata.training_done', index=4, + number=5, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=100, + serialized_end=207, +) + + +_CHUNK = _descriptor.Descriptor( + name='Chunk', + full_name='debugger.Chunk', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='buffer', full_name='debugger.Chunk.buffer', index=0, + number=1, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=209, + serialized_end=232, +) + + +_EVENTREPLY = _descriptor.Descriptor( + name='EventReply', + full_name='debugger.EventReply', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='status', full_name='debugger.EventReply.status', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='exit', full_name='debugger.EventReply.exit', index=1, + number=2, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + 
name='run_cmd', full_name='debugger.EventReply.run_cmd', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='set_cmd', full_name='debugger.EventReply.set_cmd', index=3, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='view_cmd', full_name='debugger.EventReply.view_cmd', index=4, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _EVENTREPLY_STATUS, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='cmd', full_name='debugger.EventReply.cmd', + index=0, containing_type=None, fields=[]), + ], + serialized_start=235, + serialized_end=471, +) + + +_RUNCMD = _descriptor.Descriptor( + name='RunCMD', + full_name='debugger.RunCMD', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='run_level', full_name='debugger.RunCMD.run_level', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='run_steps', full_name='debugger.RunCMD.run_steps', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='node_name', full_name='debugger.RunCMD.node_name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='cmd', full_name='debugger.RunCMD.cmd', + index=0, containing_type=None, fields=[]), + ], + serialized_start=473, + serialized_end=549, +) + + +_SETCMD = _descriptor.Descriptor( + name='SetCMD', + full_name='debugger.SetCMD', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='watch_nodes', full_name='debugger.SetCMD.watch_nodes', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='watch_condition', full_name='debugger.SetCMD.watch_condition', index=1, + number=2, type=11, cpp_type=10, label=1, + 
has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='delete', full_name='debugger.SetCMD.delete', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='id', full_name='debugger.SetCMD.id', index=3, + number=4, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=552, + serialized_end=681, +) + + +_VIEWCMD = _descriptor.Descriptor( + name='ViewCMD', + full_name='debugger.ViewCMD', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='tensors', full_name='debugger.ViewCMD.tensors', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=683, + serialized_end=732, +) + + +_WATCHCONDITION = _descriptor.Descriptor( + name='WatchCondition', + full_name='debugger.WatchCondition', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='condition', full_name='debugger.WatchCondition.condition', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='debugger.WatchCondition.value', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _WATCHCONDITION_CONDITION, + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=735, + serialized_end=973, +) + + +_WATCHNODE = _descriptor.Descriptor( + name='WatchNode', + full_name='debugger.WatchNode', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='node_name', full_name='debugger.WatchNode.node_name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='node_type', full_name='debugger.WatchNode.node_type', index=1, + number=2, type=9, cpp_type=9, label=1, + 
has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=975, + serialized_end=1024, +) + + +_WATCHPOINTHIT = _descriptor.Descriptor( + name='WatchpointHit', + full_name='debugger.WatchpointHit', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='tensor', full_name='debugger.WatchpointHit.tensor', index=0, + number=1, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='watch_condition', full_name='debugger.WatchpointHit.watch_condition', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='id', full_name='debugger.WatchpointHit.id', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1026, + serialized_end=1143, +) + +_EVENTREPLY.fields_by_name['status'].enum_type = _EVENTREPLY_STATUS +_EVENTREPLY.fields_by_name['run_cmd'].message_type = _RUNCMD +_EVENTREPLY.fields_by_name['set_cmd'].message_type = _SETCMD +_EVENTREPLY.fields_by_name['view_cmd'].message_type = _VIEWCMD +_EVENTREPLY_STATUS.containing_type = _EVENTREPLY +_EVENTREPLY.oneofs_by_name['cmd'].fields.append( + _EVENTREPLY.fields_by_name['exit']) +_EVENTREPLY.fields_by_name['exit'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] +_EVENTREPLY.oneofs_by_name['cmd'].fields.append( + _EVENTREPLY.fields_by_name['run_cmd']) +_EVENTREPLY.fields_by_name['run_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] +_EVENTREPLY.oneofs_by_name['cmd'].fields.append( + _EVENTREPLY.fields_by_name['set_cmd']) +_EVENTREPLY.fields_by_name['set_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] +_EVENTREPLY.oneofs_by_name['cmd'].fields.append( + _EVENTREPLY.fields_by_name['view_cmd']) +_EVENTREPLY.fields_by_name['view_cmd'].containing_oneof = _EVENTREPLY.oneofs_by_name['cmd'] +_RUNCMD.oneofs_by_name['cmd'].fields.append( + _RUNCMD.fields_by_name['run_steps']) +_RUNCMD.fields_by_name['run_steps'].containing_oneof = _RUNCMD.oneofs_by_name['cmd'] +_RUNCMD.oneofs_by_name['cmd'].fields.append( + _RUNCMD.fields_by_name['node_name']) +_RUNCMD.fields_by_name['node_name'].containing_oneof = _RUNCMD.oneofs_by_name['cmd'] +_SETCMD.fields_by_name['watch_nodes'].message_type = _WATCHNODE +_SETCMD.fields_by_name['watch_condition'].message_type = _WATCHCONDITION +_VIEWCMD.fields_by_name['tensors'].message_type = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO +_WATCHCONDITION.fields_by_name['condition'].enum_type = 
_WATCHCONDITION_CONDITION +_WATCHCONDITION_CONDITION.containing_type = _WATCHCONDITION +_WATCHPOINTHIT.fields_by_name['tensor'].message_type = mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO +_WATCHPOINTHIT.fields_by_name['watch_condition'].message_type = _WATCHCONDITION +DESCRIPTOR.message_types_by_name['Metadata'] = _METADATA +DESCRIPTOR.message_types_by_name['Chunk'] = _CHUNK +DESCRIPTOR.message_types_by_name['EventReply'] = _EVENTREPLY +DESCRIPTOR.message_types_by_name['RunCMD'] = _RUNCMD +DESCRIPTOR.message_types_by_name['SetCMD'] = _SETCMD +DESCRIPTOR.message_types_by_name['ViewCMD'] = _VIEWCMD +DESCRIPTOR.message_types_by_name['WatchCondition'] = _WATCHCONDITION +DESCRIPTOR.message_types_by_name['WatchNode'] = _WATCHNODE +DESCRIPTOR.message_types_by_name['WatchpointHit'] = _WATCHPOINTHIT +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Metadata = _reflection.GeneratedProtocolMessageType('Metadata', (_message.Message,), { + 'DESCRIPTOR' : _METADATA, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.Metadata) + }) +_sym_db.RegisterMessage(Metadata) + +Chunk = _reflection.GeneratedProtocolMessageType('Chunk', (_message.Message,), { + 'DESCRIPTOR' : _CHUNK, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.Chunk) + }) +_sym_db.RegisterMessage(Chunk) + +EventReply = _reflection.GeneratedProtocolMessageType('EventReply', (_message.Message,), { + 'DESCRIPTOR' : _EVENTREPLY, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.EventReply) + }) +_sym_db.RegisterMessage(EventReply) + +RunCMD = _reflection.GeneratedProtocolMessageType('RunCMD', (_message.Message,), { + 'DESCRIPTOR' : _RUNCMD, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.RunCMD) + }) +_sym_db.RegisterMessage(RunCMD) + +SetCMD = _reflection.GeneratedProtocolMessageType('SetCMD', (_message.Message,), { + 'DESCRIPTOR' : _SETCMD, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.SetCMD) + }) +_sym_db.RegisterMessage(SetCMD) + +ViewCMD = _reflection.GeneratedProtocolMessageType('ViewCMD', (_message.Message,), { + 'DESCRIPTOR' : _VIEWCMD, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.ViewCMD) + }) +_sym_db.RegisterMessage(ViewCMD) + +WatchCondition = _reflection.GeneratedProtocolMessageType('WatchCondition', (_message.Message,), { + 'DESCRIPTOR' : _WATCHCONDITION, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.WatchCondition) + }) +_sym_db.RegisterMessage(WatchCondition) + +WatchNode = _reflection.GeneratedProtocolMessageType('WatchNode', (_message.Message,), { + 'DESCRIPTOR' : _WATCHNODE, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.WatchNode) + }) +_sym_db.RegisterMessage(WatchNode) + +WatchpointHit = _reflection.GeneratedProtocolMessageType('WatchpointHit', (_message.Message,), { + 'DESCRIPTOR' : _WATCHPOINTHIT, + '__module__' : 'mindinsight.debugger.proto.debug_grpc_pb2' + # @@protoc_insertion_point(class_scope:debugger.WatchpointHit) + }) +_sym_db.RegisterMessage(WatchpointHit) + + + +_EVENTLISTENER = _descriptor.ServiceDescriptor( + name='EventListener', + full_name='debugger.EventListener', + file=DESCRIPTOR, + 
index=0, + serialized_options=None, + serialized_start=1146, + serialized_end=1469, + methods=[ + _descriptor.MethodDescriptor( + name='WaitCMD', + full_name='debugger.EventListener.WaitCMD', + index=0, + containing_service=None, + input_type=_METADATA, + output_type=_EVENTREPLY, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='SendMetadata', + full_name='debugger.EventListener.SendMetadata', + index=1, + containing_service=None, + input_type=_METADATA, + output_type=_EVENTREPLY, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='SendGraph', + full_name='debugger.EventListener.SendGraph', + index=2, + containing_service=None, + input_type=_CHUNK, + output_type=_EVENTREPLY, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='SendTensors', + full_name='debugger.EventListener.SendTensors', + index=3, + containing_service=None, + input_type=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2._TENSORPROTO, + output_type=_EVENTREPLY, + serialized_options=None, + ), + _descriptor.MethodDescriptor( + name='SendWatchpointHits', + full_name='debugger.EventListener.SendWatchpointHits', + index=4, + containing_service=None, + input_type=_WATCHPOINTHIT, + output_type=_EVENTREPLY, + serialized_options=None, + ), +]) +_sym_db.RegisterServiceDescriptor(_EVENTLISTENER) + +DESCRIPTOR.services_by_name['EventListener'] = _EVENTLISTENER + +# @@protoc_insertion_point(module_scope) diff --git a/mindinsight/debugger/proto/debug_grpc_pb2_grpc.py b/mindinsight/debugger/proto/debug_grpc_pb2_grpc.py new file mode 100644 index 00000000..9abc88ce --- /dev/null +++ b/mindinsight/debugger/proto/debug_grpc_pb2_grpc.py @@ -0,0 +1,193 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +import grpc + +from mindinsight.debugger.proto import debug_grpc_pb2 as mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2 +from mindinsight.debugger.proto import ms_graph_pb2 as mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2 + + +class EventListenerStub(object): + """Missing associated documentation comment in .proto file""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.WaitCMD = channel.unary_unary( + '/debugger.EventListener/WaitCMD', + request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, + response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + ) + self.SendMetadata = channel.unary_unary( + '/debugger.EventListener/SendMetadata', + request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, + response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + ) + self.SendGraph = channel.stream_unary( + '/debugger.EventListener/SendGraph', + request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, + response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + ) + self.SendTensors = channel.stream_unary( + '/debugger.EventListener/SendTensors', + request_serializer=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString, + response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + ) + self.SendWatchpointHits = channel.stream_unary( + '/debugger.EventListener/SendWatchpointHits', + request_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString, + response_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + ) + + +class EventListenerServicer(object): + """Missing associated documentation comment in .proto file""" + + def WaitCMD(self, request, context): + """Missing associated documentation comment in .proto file""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendMetadata(self, request, context): + """Missing associated documentation comment in .proto file""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendGraph(self, request_iterator, context): + """Missing associated documentation comment in .proto file""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendTensors(self, request_iterator, context): + """Missing associated documentation comment in .proto file""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendWatchpointHits(self, request_iterator, context): + """Missing associated documentation comment in .proto file""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_EventListenerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'WaitCMD': grpc.unary_unary_rpc_method_handler( + servicer.WaitCMD, + request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.FromString, + response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, + ), + 'SendMetadata': grpc.unary_unary_rpc_method_handler( + servicer.SendMetadata, + request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.FromString, + 
response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, + ), + 'SendGraph': grpc.stream_unary_rpc_method_handler( + servicer.SendGraph, + request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.FromString, + response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, + ), + 'SendTensors': grpc.stream_unary_rpc_method_handler( + servicer.SendTensors, + request_deserializer=mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.FromString, + response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, + ), + 'SendWatchpointHits': grpc.stream_unary_rpc_method_handler( + servicer.SendWatchpointHits, + request_deserializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.FromString, + response_serializer=mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'debugger.EventListener', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class EventListener(object): + """Missing associated documentation comment in .proto file""" + + @staticmethod + def WaitCMD(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/debugger.EventListener/WaitCMD', + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + options, channel_credentials, + call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SendMetadata(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/debugger.EventListener/SendMetadata', + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Metadata.SerializeToString, + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + options, channel_credentials, + call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SendGraph(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendGraph', + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.Chunk.SerializeToString, + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + options, channel_credentials, + call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def SendTensors(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendTensors', + mindinsight_dot_debugger_dot_proto_dot_ms__graph__pb2.TensorProto.SerializeToString, + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + options, channel_credentials, + call_credentials, compression, 
wait_for_ready, timeout, metadata) + + @staticmethod + def SendWatchpointHits(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary(request_iterator, target, '/debugger.EventListener/SendWatchpointHits', + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.WatchpointHit.SerializeToString, + mindinsight_dot_debugger_dot_proto_dot_debug__grpc__pb2.EventReply.FromString, + options, channel_credentials, + call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/mindinsight/debugger/proto/ms_graph.proto b/mindinsight/debugger/proto/ms_graph.proto new file mode 100644 index 00000000..0a17f460 --- /dev/null +++ b/mindinsight/debugger/proto/ms_graph.proto @@ -0,0 +1,322 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package debugger; + +// Versioning +enum Version { + // unknown version + UNKNOWWN_VERSION = 0; + + // Initial version (IR VERSION 1), published on Sep 23, 2019 + IR_VERSION = 0x0000000000000001; +} + +// Data type definition +enum DataType { + DT_UNDEFINED = 0; + // Basic types. 
+ DT_BOOL = 1; // bool
+
+ DT_INT8 = 2; // int8_t
+ DT_INT16 = 3; // int16_t
+ DT_INT32 = 4; // int32_t
+ DT_INT64 = 5; // int64_t
+
+ DT_UINT8 = 6; // uint8_t
+ DT_UINT16 = 7; // uint16_t
+ DT_UINT32 = 8; // uint32_t
+ DT_UINT64 = 9; // uint64_t
+
+ DT_FLOAT16 = 10; // float 16
+ DT_FLOAT32 = 11; // float 32
+ DT_FLOAT64 = 12; // float 64
+
+ DT_STRING = 13; // string
+ DT_TENSOR = 14; // tensor
+ DT_GRAPH = 15; // graph
+
+ // list type
+ DT_BOOLS = 16; // list of bool
+
+ DT_INTS8 = 17; // list of int8_t
+ DT_INTS16 = 18; // list of int16_t
+ DT_INTS32 = 19; // list of int32_t
+ DT_INTS64 = 20; // list of int64_t
+
+ DT_UINTS8 = 21; // list of uint8_t
+ DT_UINTS16 = 22; // list of uint16_t
+ DT_UINTS32 = 23; // list of uint32_t
+ DT_UINTS64 = 24; // list of uint64_t
+
+ DT_FLOATS16 = 25; // list of float16
+ DT_FLOATS32 = 26; // list of float32
+ DT_FLOATS64 = 27; // list of float64
+
+ DT_STRINGS = 28; // list of string
+ DT_TENSORS = 29; // list of tensor
+ DT_GRAPHS = 30; // list of graph
+
+ DT_TUPLE = 31; // tuple
+ DT_LIST = 32; // list
+ DT_DICT = 33; // dictionary
+
+ // other types
+ DT_NONE = 34; // None
+ DT_SYM_INST = 35; // Symbolic Key Instance
+
+ // type related type
+ DT_BASE_INT = 36; // type generic int
+ DT_BASE_UINT = 37; // type generic unsigned int
+ DT_BASE_FLOAT = 38; // type generic float
+ DT_TYPE = 39; // type type
+ DT_ANYTHING = 40; // type anything
+ DT_REFKEY = 41; // type refkey
+ DT_REF = 42; // type ref
+}
+
+// Value definition for attribute value or parameter default value
+message ValueProto {
+ // data type of value
+ optional DataType dtype = 1; // discriminator that indicates which field below is in use
+
+ // Exactly ONE of the following fields must be present for this version of the IR
+ optional bool bool_val = 2; // bool
+ optional int64 int_val = 3; // int
+ optional uint64 uint_val = 4; // uint
+ optional float float_val = 5; // float
+ optional double double_val = 6; // double
+ optional string str_val = 7; // string
+ optional TensorProto tensor_val = 8; // tensor value
+ optional GraphProto graph = 9; // graph
+
+ repeated bool bool_vals = 10; // list of bool
+ repeated int64 int_vals = 11; // list of int
+ repeated uint64 uint_vals = 12; // list of uint
+ repeated float float_vals = 13; // list of float
+ repeated double double_vals = 14; // list of double
+ repeated string str_vals = 15; // list of string
+ repeated TensorProto tensor_vals = 16; // list of tensor value
+ repeated GraphProto graphs = 17; // list of graph
+
+ // tuple or list
+ repeated ValueProto values = 18; // tuple, list of value
+
+ // dictionary
+ repeated NamedValueProto dict_val = 19; // dictionary info
+
+ // field for type type
+ optional TypeProto type_val = 20; // type type info
+}
+
+message AttributeProto {
+ optional string name = 1; // attribute name
+ optional ValueProto value = 2; // attribute value
+}
+
+message NamedValueProto {
+ optional string key = 1; // key name
+ optional ValueProto value = 2; // value of the entry
+}
+
+// Defines a tensor shape.
+message TensorShapeProto {
+ // One dimension of the tensor.
+ message Dimension {
+ // Size of the tensor in that dimension.
+ // This value must be >= -1, where a value of -1 is reserved to
+ // mean the dimension is unknown.
+ optional int64 size = 1;
+
+ // Optional name of the tensor dimension.
+ optional string name = 2;
+ };
+
+ repeated Dimension dim = 1;
+}
+
+// Types for graph input (parameter) and output
+message TypeProto {
+
+ message Tensor {
+ // This field MUST have a valid DataType value except DT_TENSOR
+ optional DataType elem_type = 1;
+ optional TensorShapeProto shape = 2; // for scalar, this field is not set
+ }
+
+ // tuple type
+ message Sequence {
+ // The type and optional shape of elements of the tuple.
+ repeated TypeProto elem_types = 1;
+ };
+
+ // data type
+ optional DataType data_type = 1;
+
+ oneof value {
+ // The type of a tensor.
+ Tensor tensor_type = 2;
+
+ // The type of a tuple.
+ Sequence sequence_type = 3;
+ }
+}
+
+// Defines information on graph parameters, including the name, the type, and
+// the default value of the parameter if it exists.
+message ParameterProto {
+ optional string name = 1; // parameter name
+ optional TypeProto type = 2; // parameter type
+ optional ValueProto default_val = 3; // default value of the parameter if it exists
+}
+
+// Defines graph output information
+message OutputProto {
+ optional string name = 1; // output node name
+ optional TypeProto type = 2; // output node type
+}
+
+// Defines node input information
+message InputProto {
+ enum EdgeType {
+ DATA_EDGE = 0; // data edge
+ CONTROL_EDGE = 1; // control edge
+ }
+
+ optional string name = 1;
+ optional EdgeType type = 2;
+}
+
+// Nodes
+//
+// Computation graphs are made up of a DAG of nodes, which represent what is
+// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+//
+// For example, it can be a node of type "Conv" that takes in an image, a filter
+// tensor and a bias tensor, and produces the convolved output.
+message NodeProto {
+ repeated InputProto input = 1; // namespace Value
+ optional string name = 2; // namespace Value
+
+ // The symbolic identifier of the Operator to execute.
+ optional string op_type = 3; // namespace Operator
+ // The domain of the OperatorSet that specifies the operator named by op_type.
+ optional string scope = 4; // namespace Domain
+
+ // Additional named attributes.
+ repeated AttributeProto attribute = 5;
+
+ // Optional type info of this node
+ optional TypeProto output_type = 6;
+
+ // other fields for debug
+ optional uint64 output_i = 7;
+
+ // full name with scope
+ optional string full_name = 8;
+}
+
+// Models
+//
+// ModelProto is a top-level file/container format for bundling an ML model and
+// associating its computation graph with metadata.
+//
+// The semantics of the model are described by the associated GraphProto.
+message ModelProto {
+ // ir version
+ optional int64 ir_version = 1;
+
+ // Domain name of the model.
+ // We use reverse domain names as name space indicators. For example:
+ // `com.facebook.fair` or `com.microsoft.cognitiveservices`
+ //
+ // Together with `model_version` and GraphProto.name, this forms the unique identity of
+ // the graph.
+ optional string domain = 2;
+
+ // The version of the graph encoded. See the Version enum above.
+ optional int64 model_version = 3;
+
+ // The parameterized graph that is evaluated to execute the model.
+ optional GraphProto graph = 4;
+
+ // metadata info of operators
+ optional OperatorSetProto metadata_operators = 5;
+};
+
+message OperatorProto {
+ optional string name = 1; // used as key, must be distinct
+ optional bytes config = 2; // operator config info
+ optional bytes obj_info = 3; // operator-related object info, e.g. the content of the operator binary, or its name
+};
+
+message OperatorSetProto {
+ repeated OperatorProto operators = 1;
+ optional string summary = 2; // summary info of operators, e.g. the file position of the operators file
+}
+
+// Graphs
+//
+// A graph defines the computational logic of a model and is composed of a parameterized
+// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+// This is the equivalent of the "network" or "graph" in many deep learning
+// frameworks.
+message GraphProto {
+ // The nodes in the graph, sorted topologically.
+ repeated NodeProto node = 1;
+
+ // The name of the graph.
+ optional string name = 2; // namespace Graph
+
+ // The parameters (inputs) and outputs of the graph.
+ repeated ParameterProto parameters = 3;
+ repeated OutputProto outputs = 4;
+
+ // Constants used in this graph
+ repeated NamedValueProto const_vals = 5;
+}
+
+// Tensors
+//
+// A serialized tensor value.
+message TensorProto {
+ // The node name of the tensor.
+ optional string node_name = 1;
+
+ // The slot of the tensor in its node.
+ optional string slot = 2;
+
+ // The serialized tensor content.
+ optional bytes tensor_content = 3;
+
+ // The shape of the tensor.
+ repeated int64 dims = 4;
+
+ // The data type of the tensor.
+ // This field MUST have a valid DataType value except DT_TENSOR
+ optional DataType data_type = 5;
+
+ // Whether the tensor content transfer is finished.
+ optional bool finished = 6;
+
+ // The iteration of the tensor. Supported: "prev", or leave empty.
+ optional string iter = 7;
+
+ // Whether the tensor name should be truncated.
+ optional bool truncate = 8;
+} \ No newline at end of file
diff --git a/mindinsight/debugger/proto/ms_graph_pb2.py b/mindinsight/debugger/proto/ms_graph_pb2.py
new file mode 100644
index 00000000..d2a791d1
--- /dev/null
+++ b/mindinsight/debugger/proto/ms_graph_pb2.py
@@ -0,0 +1,1395 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
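+#
+# A minimal usage sketch of the messages generated below (illustrative only,
+# not part of the protoc output; names such as `demo_tensor` and
+# `DemoListener` are hypothetical, while the message, field, enum, and
+# service names come from ms_graph.proto and debug_grpc.proto):
+#
+#     from mindinsight.debugger.proto import debug_grpc_pb2
+#     from mindinsight.debugger.proto import debug_grpc_pb2_grpc
+#     from mindinsight.debugger.proto import ms_graph_pb2
+#
+#     # Build a TensorProto describing a 2x3 float32 tensor.
+#     demo_tensor = ms_graph_pb2.TensorProto()
+#     demo_tensor.node_name = 'Default/conv1'
+#     demo_tensor.slot = '0'
+#     demo_tensor.data_type = ms_graph_pb2.DT_FLOAT32
+#     demo_tensor.dims.extend([2, 3])
+#     demo_tensor.tensor_content = b'\x00' * (2 * 3 * 4)  # raw tensor bytes
+#     demo_tensor.finished = True
+#
+#     # Round-trip through the protobuf wire format.
+#     payload = demo_tensor.SerializeToString()
+#     parsed = ms_graph_pb2.TensorProto.FromString(payload)
+#     assert list(parsed.dims) == [2, 3]
+#
+#     # TensorProto is also the request type streamed to the EventListener
+#     # service: a servicer receives the tensors chunk by chunk and answers
+#     # with a single EventReply.
+#     class DemoListener(debug_grpc_pb2_grpc.EventListenerServicer):
+#         def SendTensors(self, request_iterator, context):
+#             for chunk in request_iterator:
+#                 print(chunk.node_name, chunk.slot, chunk.finished)
+#             return debug_grpc_pb2.EventReply()
+#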
+# source: mindinsight/debugger/proto/ms_graph.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='mindinsight/debugger/proto/ms_graph.proto', + package='debugger', + syntax='proto2', + serialized_options=None, + serialized_pb=_b('\n)mindinsight/debugger/proto/ms_graph.proto\x12\x08\x64\x65\x62ugger\"\xab\x04\n\nValueProto\x12!\n\x05\x64type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12\x10\n\x08\x62ool_val\x18\x02 \x01(\x08\x12\x0f\n\x07int_val\x18\x03 \x01(\x03\x12\x10\n\x08uint_val\x18\x04 \x01(\x04\x12\x11\n\tfloat_val\x18\x05 \x01(\x02\x12\x12\n\ndouble_val\x18\x06 \x01(\x01\x12\x0f\n\x07str_val\x18\x07 \x01(\t\x12)\n\ntensor_val\x18\x08 \x01(\x0b\x32\x15.debugger.TensorProto\x12#\n\x05graph\x18\t \x01(\x0b\x32\x14.debugger.GraphProto\x12\x11\n\tbool_vals\x18\n \x03(\x08\x12\x10\n\x08int_vals\x18\x0b \x03(\x03\x12\x11\n\tuint_vals\x18\x0c \x03(\x04\x12\x12\n\nfloat_vals\x18\r \x03(\x02\x12\x13\n\x0b\x64ouble_vals\x18\x0e \x03(\x01\x12\x10\n\x08str_vals\x18\x0f \x03(\t\x12*\n\x0btensor_vals\x18\x10 \x03(\x0b\x32\x15.debugger.TensorProto\x12$\n\x06graphs\x18\x11 \x03(\x0b\x32\x14.debugger.GraphProto\x12$\n\x06values\x18\x12 \x03(\x0b\x32\x14.debugger.ValueProto\x12+\n\x08\x64ict_val\x18\x13 \x03(\x0b\x32\x19.debugger.NamedValueProto\x12%\n\x08type_val\x18\x14 \x01(\x0b\x32\x13.debugger.TypeProto\"C\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.debugger.ValueProto\"C\n\x0fNamedValueProto\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.debugger.ValueProto\"n\n\x10TensorShapeProto\x12\x31\n\x03\x64im\x18\x01 \x03(\x0b\x32$.debugger.TensorShapeProto.Dimension\x1a\'\n\tDimension\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\"\xb6\x02\n\tTypeProto\x12%\n\tdata_type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12\x31\n\x0btensor_type\x18\x02 \x01(\x0b\x32\x1a.debugger.TypeProto.TensorH\x00\x12\x35\n\rsequence_type\x18\x03 \x01(\x0b\x32\x1c.debugger.TypeProto.SequenceH\x00\x1aZ\n\x06Tensor\x12%\n\telem_type\x18\x01 \x01(\x0e\x32\x12.debugger.DataType\x12)\n\x05shape\x18\x02 \x01(\x0b\x32\x1a.debugger.TensorShapeProto\x1a\x33\n\x08Sequence\x12\'\n\nelem_types\x18\x01 \x03(\x0b\x32\x13.debugger.TypeProtoB\x07\n\x05value\"l\n\x0eParameterProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12!\n\x04type\x18\x02 \x01(\x0b\x32\x13.debugger.TypeProto\x12)\n\x0b\x64\x65\x66\x61ult_val\x18\x03 \x01(\x0b\x32\x14.debugger.ValueProto\">\n\x0bOutputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12!\n\x04type\x18\x02 \x01(\x0b\x32\x13.debugger.TypeProto\"t\n\nInputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12+\n\x04type\x18\x02 \x01(\x0e\x32\x1d.debugger.InputProto.EdgeType\"+\n\x08\x45\x64geType\x12\r\n\tDATA_EDGE\x10\x00\x12\x10\n\x0c\x43ONTROL_EDGE\x10\x01\"\xda\x01\n\tNodeProto\x12#\n\x05input\x18\x01 \x03(\x0b\x32\x14.debugger.InputProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0f\n\x07op_type\x18\x03 \x01(\t\x12\r\n\x05scope\x18\x04 \x01(\t\x12+\n\tattribute\x18\x05 \x03(\x0b\x32\x18.debugger.AttributeProto\x12(\n\x0boutput_type\x18\x06 
\x01(\x0b\x32\x13.debugger.TypeProto\x12\x10\n\x08output_i\x18\x07 \x01(\x04\x12\x11\n\tfull_name\x18\x08 \x01(\t\"\xa4\x01\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\x03\x12\x0e\n\x06\x64omain\x18\x02 \x01(\t\x12\x15\n\rmodel_version\x18\x03 \x01(\x03\x12#\n\x05graph\x18\x04 \x01(\x0b\x32\x14.debugger.GraphProto\x12\x36\n\x12metadata_operators\x18\x05 \x01(\x0b\x32\x1a.debugger.OperatorSetProto\"?\n\rOperatorProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x02 \x01(\x0c\x12\x10\n\x08obj_info\x18\x03 \x01(\x0c\"O\n\x10OperatorSetProto\x12*\n\toperators\x18\x01 \x03(\x0b\x32\x17.debugger.OperatorProto\x12\x0f\n\x07summary\x18\x02 \x01(\t\"\xc2\x01\n\nGraphProto\x12!\n\x04node\x18\x01 \x03(\x0b\x32\x13.debugger.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12,\n\nparameters\x18\x03 \x03(\x0b\x32\x18.debugger.ParameterProto\x12&\n\x07outputs\x18\x04 \x03(\x0b\x32\x15.debugger.OutputProto\x12-\n\nconst_vals\x18\x05 \x03(\x0b\x32\x19.debugger.NamedValueProto\"\xad\x01\n\x0bTensorProto\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x0c\n\x04slot\x18\x02 \x01(\t\x12\x16\n\x0etensor_content\x18\x03 \x01(\x0c\x12\x0c\n\x04\x64ims\x18\x04 \x03(\x03\x12%\n\tdata_type\x18\x05 \x01(\x0e\x32\x12.debugger.DataType\x12\x10\n\x08\x66inished\x18\x06 \x01(\x08\x12\x0c\n\x04iter\x18\x07 \x01(\t\x12\x10\n\x08truncate\x18\x08 \x01(\x08*/\n\x07Version\x12\x14\n\x10UNKNOWWN_VERSION\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01*\x96\x05\n\x08\x44\x61taType\x12\x10\n\x0c\x44T_UNDEFINED\x10\x00\x12\x0b\n\x07\x44T_BOOL\x10\x01\x12\x0b\n\x07\x44T_INT8\x10\x02\x12\x0c\n\x08\x44T_INT16\x10\x03\x12\x0c\n\x08\x44T_INT32\x10\x04\x12\x0c\n\x08\x44T_INT64\x10\x05\x12\x0c\n\x08\x44T_UINT8\x10\x06\x12\r\n\tDT_UINT16\x10\x07\x12\r\n\tDT_UINT32\x10\x08\x12\r\n\tDT_UINT64\x10\t\x12\x0e\n\nDT_FLOAT16\x10\n\x12\x0e\n\nDT_FLOAT32\x10\x0b\x12\x0e\n\nDT_FLOAT64\x10\x0c\x12\r\n\tDT_STRING\x10\r\x12\r\n\tDT_TENSOR\x10\x0e\x12\x0c\n\x08\x44T_GRAPH\x10\x0f\x12\x0c\n\x08\x44T_BOOLS\x10\x10\x12\x0c\n\x08\x44T_INTS8\x10\x11\x12\r\n\tDT_INTS16\x10\x12\x12\r\n\tDT_INTS32\x10\x13\x12\r\n\tDT_INTS64\x10\x14\x12\r\n\tDT_UINTS8\x10\x15\x12\x0e\n\nDT_UINTS16\x10\x16\x12\x0e\n\nDT_UINTS32\x10\x17\x12\x0e\n\nDT_UINTS64\x10\x18\x12\x0f\n\x0b\x44T_FLOATS16\x10\x19\x12\x0f\n\x0b\x44T_FLOATS32\x10\x1a\x12\x0f\n\x0b\x44T_FLOATS64\x10\x1b\x12\x0e\n\nDT_STRINGS\x10\x1c\x12\x0e\n\nDT_TENSORS\x10\x1d\x12\r\n\tDT_GRAPHS\x10\x1e\x12\x0c\n\x08\x44T_TUPLE\x10\x1f\x12\x0b\n\x07\x44T_LIST\x10 \x12\x0b\n\x07\x44T_DICT\x10!\x12\x0b\n\x07\x44T_NONE\x10\"\x12\x0f\n\x0b\x44T_SYM_INST\x10#\x12\x0f\n\x0b\x44T_BASE_INT\x10$\x12\x10\n\x0c\x44T_BASE_UINT\x10%\x12\x11\n\rDT_BASE_FLOAT\x10&\x12\x0b\n\x07\x44T_TYPE\x10\'\x12\x0f\n\x0b\x44T_ANYTHING\x10(\x12\r\n\tDT_REFKEY\x10)\x12\n\n\x06\x44T_REF\x10*') +) + +_VERSION = _descriptor.EnumDescriptor( + name='Version', + full_name='debugger.Version', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='UNKNOWWN_VERSION', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='IR_VERSION', index=1, number=1, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=2375, + serialized_end=2422, +) +_sym_db.RegisterEnumDescriptor(_VERSION) + +Version = enum_type_wrapper.EnumTypeWrapper(_VERSION) +_DATATYPE = _descriptor.EnumDescriptor( + name='DataType', + full_name='debugger.DataType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + 
name='DT_UNDEFINED', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_BOOL', index=1, number=1, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INT8', index=2, number=2, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INT16', index=3, number=3, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INT32', index=4, number=4, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INT64', index=5, number=5, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINT8', index=6, number=6, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINT16', index=7, number=7, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINT32', index=8, number=8, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINT64', index=9, number=9, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_FLOAT16', index=10, number=10, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_FLOAT32', index=11, number=11, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_FLOAT64', index=12, number=12, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_STRING', index=13, number=13, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_TENSOR', index=14, number=14, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_GRAPH', index=15, number=15, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_BOOLS', index=16, number=16, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INTS8', index=17, number=17, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INTS16', index=18, number=18, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INTS32', index=19, number=19, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INTS64', index=20, number=20, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINTS8', index=21, number=21, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINTS16', index=22, number=22, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINTS32', index=23, number=23, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINTS64', index=24, number=24, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_FLOATS16', index=25, number=25, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_FLOATS32', index=26, number=26, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_FLOATS64', index=27, number=27, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_STRINGS', index=28, number=28, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_TENSORS', index=29, number=29, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_GRAPHS', index=30, number=30, + 
serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_TUPLE', index=31, number=31, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_LIST', index=32, number=32, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_DICT', index=33, number=33, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_NONE', index=34, number=34, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_SYM_INST', index=35, number=35, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_BASE_INT', index=36, number=36, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_BASE_UINT', index=37, number=37, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_BASE_FLOAT', index=38, number=38, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_TYPE', index=39, number=39, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_ANYTHING', index=40, number=40, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_REFKEY', index=41, number=41, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_REF', index=42, number=42, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=2425, + serialized_end=3087, +) +_sym_db.RegisterEnumDescriptor(_DATATYPE) + +DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) +UNKNOWWN_VERSION = 0 +IR_VERSION = 1 +DT_UNDEFINED = 0 +DT_BOOL = 1 +DT_INT8 = 2 +DT_INT16 = 3 +DT_INT32 = 4 +DT_INT64 = 5 +DT_UINT8 = 6 +DT_UINT16 = 7 +DT_UINT32 = 8 +DT_UINT64 = 9 +DT_FLOAT16 = 10 +DT_FLOAT32 = 11 +DT_FLOAT64 = 12 +DT_STRING = 13 +DT_TENSOR = 14 +DT_GRAPH = 15 +DT_BOOLS = 16 +DT_INTS8 = 17 +DT_INTS16 = 18 +DT_INTS32 = 19 +DT_INTS64 = 20 +DT_UINTS8 = 21 +DT_UINTS16 = 22 +DT_UINTS32 = 23 +DT_UINTS64 = 24 +DT_FLOATS16 = 25 +DT_FLOATS32 = 26 +DT_FLOATS64 = 27 +DT_STRINGS = 28 +DT_TENSORS = 29 +DT_GRAPHS = 30 +DT_TUPLE = 31 +DT_LIST = 32 +DT_DICT = 33 +DT_NONE = 34 +DT_SYM_INST = 35 +DT_BASE_INT = 36 +DT_BASE_UINT = 37 +DT_BASE_FLOAT = 38 +DT_TYPE = 39 +DT_ANYTHING = 40 +DT_REFKEY = 41 +DT_REF = 42 + + +_INPUTPROTO_EDGETYPE = _descriptor.EnumDescriptor( + name='EdgeType', + full_name='debugger.InputProto.EdgeType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='DATA_EDGE', index=0, number=0, + serialized_options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='CONTROL_EDGE', index=1, number=1, + serialized_options=None, + type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=1423, + serialized_end=1466, +) +_sym_db.RegisterEnumDescriptor(_INPUTPROTO_EDGETYPE) + + +_VALUEPROTO = _descriptor.Descriptor( + name='ValueProto', + full_name='debugger.ValueProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='dtype', full_name='debugger.ValueProto.dtype', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='bool_val', full_name='debugger.ValueProto.bool_val', index=1, + number=2, 
type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='int_val', full_name='debugger.ValueProto.int_val', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='uint_val', full_name='debugger.ValueProto.uint_val', index=3, + number=4, type=4, cpp_type=4, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='float_val', full_name='debugger.ValueProto.float_val', index=4, + number=5, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='double_val', full_name='debugger.ValueProto.double_val', index=5, + number=6, type=1, cpp_type=5, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='str_val', full_name='debugger.ValueProto.str_val', index=6, + number=7, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensor_val', full_name='debugger.ValueProto.tensor_val', index=7, + number=8, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='graph', full_name='debugger.ValueProto.graph', index=8, + number=9, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='bool_vals', full_name='debugger.ValueProto.bool_vals', index=9, + number=10, type=8, cpp_type=7, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='int_vals', full_name='debugger.ValueProto.int_vals', index=10, + number=11, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='uint_vals', full_name='debugger.ValueProto.uint_vals', index=11, + number=12, type=4, cpp_type=4, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='float_vals', full_name='debugger.ValueProto.float_vals', index=12, + number=13, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='double_vals', full_name='debugger.ValueProto.double_vals', index=13, + number=14, type=1, cpp_type=5, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='str_vals', full_name='debugger.ValueProto.str_vals', index=14, + number=15, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensor_vals', full_name='debugger.ValueProto.tensor_vals', index=15, + number=16, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='graphs', full_name='debugger.ValueProto.graphs', index=16, + number=17, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='values', full_name='debugger.ValueProto.values', index=17, + number=18, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='dict_val', full_name='debugger.ValueProto.dict_val', index=18, + number=19, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type_val', full_name='debugger.ValueProto.type_val', index=19, + number=20, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=56, + serialized_end=611, +) + + +_ATTRIBUTEPROTO = _descriptor.Descriptor( + name='AttributeProto', + full_name='debugger.AttributeProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='debugger.AttributeProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + 
_descriptor.FieldDescriptor( + name='value', full_name='debugger.AttributeProto.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=613, + serialized_end=680, +) + + +_NAMEDVALUEPROTO = _descriptor.Descriptor( + name='NamedValueProto', + full_name='debugger.NamedValueProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='key', full_name='debugger.NamedValueProto.key', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='value', full_name='debugger.NamedValueProto.value', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=682, + serialized_end=749, +) + + +_TENSORSHAPEPROTO_DIMENSION = _descriptor.Descriptor( + name='Dimension', + full_name='debugger.TensorShapeProto.Dimension', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='size', full_name='debugger.TensorShapeProto.Dimension.size', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='name', full_name='debugger.TensorShapeProto.Dimension.name', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=822, + serialized_end=861, +) + +_TENSORSHAPEPROTO = _descriptor.Descriptor( + name='TensorShapeProto', + full_name='debugger.TensorShapeProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='dim', full_name='debugger.TensorShapeProto.dim', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_TENSORSHAPEPROTO_DIMENSION, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=751, + serialized_end=861, +) + + 
+_TYPEPROTO_TENSOR = _descriptor.Descriptor( + name='Tensor', + full_name='debugger.TypeProto.Tensor', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='elem_type', full_name='debugger.TypeProto.Tensor.elem_type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='shape', full_name='debugger.TypeProto.Tensor.shape', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1022, + serialized_end=1112, +) + +_TYPEPROTO_SEQUENCE = _descriptor.Descriptor( + name='Sequence', + full_name='debugger.TypeProto.Sequence', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='elem_types', full_name='debugger.TypeProto.Sequence.elem_types', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1114, + serialized_end=1165, +) + +_TYPEPROTO = _descriptor.Descriptor( + name='TypeProto', + full_name='debugger.TypeProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='data_type', full_name='debugger.TypeProto.data_type', index=0, + number=1, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensor_type', full_name='debugger.TypeProto.tensor_type', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='sequence_type', full_name='debugger.TypeProto.sequence_type', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[_TYPEPROTO_TENSOR, _TYPEPROTO_SEQUENCE, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + _descriptor.OneofDescriptor( + name='value', full_name='debugger.TypeProto.value', + index=0, containing_type=None, fields=[]), + ], + serialized_start=864, + serialized_end=1174, +) + + +_PARAMETERPROTO = _descriptor.Descriptor( + name='ParameterProto', + full_name='debugger.ParameterProto', + filename=None, + 
file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='debugger.ParameterProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='debugger.ParameterProto.type', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='default_val', full_name='debugger.ParameterProto.default_val', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1176, + serialized_end=1284, +) + + +_OUTPUTPROTO = _descriptor.Descriptor( + name='OutputProto', + full_name='debugger.OutputProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='debugger.OutputProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='debugger.OutputProto.type', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1286, + serialized_end=1348, +) + + +_INPUTPROTO = _descriptor.Descriptor( + name='InputProto', + full_name='debugger.InputProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='debugger.InputProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='type', full_name='debugger.InputProto.type', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _INPUTPROTO_EDGETYPE, + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1350, + serialized_end=1466, +) + + +_NODEPROTO = _descriptor.Descriptor( + name='NodeProto', + 
full_name='debugger.NodeProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='input', full_name='debugger.NodeProto.input', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='name', full_name='debugger.NodeProto.name', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='op_type', full_name='debugger.NodeProto.op_type', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='scope', full_name='debugger.NodeProto.scope', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='attribute', full_name='debugger.NodeProto.attribute', index=4, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='output_type', full_name='debugger.NodeProto.output_type', index=5, + number=6, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='output_i', full_name='debugger.NodeProto.output_i', index=6, + number=7, type=4, cpp_type=4, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='full_name', full_name='debugger.NodeProto.full_name', index=7, + number=8, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1469, + serialized_end=1687, +) + + +_MODELPROTO = _descriptor.Descriptor( + name='ModelProto', + full_name='debugger.ModelProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='ir_version', full_name='debugger.ModelProto.ir_version', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + 
serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='domain', full_name='debugger.ModelProto.domain', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='model_version', full_name='debugger.ModelProto.model_version', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='graph', full_name='debugger.ModelProto.graph', index=3, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='metadata_operators', full_name='debugger.ModelProto.metadata_operators', index=4, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1690, + serialized_end=1854, +) + + +_OPERATORPROTO = _descriptor.Descriptor( + name='OperatorProto', + full_name='debugger.OperatorProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='debugger.OperatorProto.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='config', full_name='debugger.OperatorProto.config', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='obj_info', full_name='debugger.OperatorProto.obj_info', index=2, + number=3, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1856, + serialized_end=1919, +) + + +_OPERATORSETPROTO = _descriptor.Descriptor( + name='OperatorSetProto', + full_name='debugger.OperatorSetProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='operators', full_name='debugger.OperatorSetProto.operators', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='summary', full_name='debugger.OperatorSetProto.summary', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1921, + serialized_end=2000, +) + + +_GRAPHPROTO = _descriptor.Descriptor( + name='GraphProto', + full_name='debugger.GraphProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='node', full_name='debugger.GraphProto.node', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='name', full_name='debugger.GraphProto.name', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='parameters', full_name='debugger.GraphProto.parameters', index=2, + number=3, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='outputs', full_name='debugger.GraphProto.outputs', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='const_vals', full_name='debugger.GraphProto.const_vals', index=4, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2003, + serialized_end=2197, +) + + +_TENSORPROTO = _descriptor.Descriptor( + name='TensorProto', + full_name='debugger.TensorProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='node_name', full_name='debugger.TensorProto.node_name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='slot', full_name='debugger.TensorProto.slot', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='tensor_content', full_name='debugger.TensorProto.tensor_content', index=2, + number=3, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='dims', full_name='debugger.TensorProto.dims', index=3, + number=4, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='data_type', full_name='debugger.TensorProto.data_type', index=4, + number=5, type=14, cpp_type=8, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='finished', full_name='debugger.TensorProto.finished', index=5, + number=6, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='iter', full_name='debugger.TensorProto.iter', index=6, + number=7, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='truncate', full_name='debugger.TensorProto.truncate', index=7, + number=8, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2200, + serialized_end=2373, +) + +_VALUEPROTO.fields_by_name['dtype'].enum_type = _DATATYPE +_VALUEPROTO.fields_by_name['tensor_val'].message_type = _TENSORPROTO +_VALUEPROTO.fields_by_name['graph'].message_type = _GRAPHPROTO +_VALUEPROTO.fields_by_name['tensor_vals'].message_type = _TENSORPROTO +_VALUEPROTO.fields_by_name['graphs'].message_type = _GRAPHPROTO +_VALUEPROTO.fields_by_name['values'].message_type = _VALUEPROTO +_VALUEPROTO.fields_by_name['dict_val'].message_type = _NAMEDVALUEPROTO +_VALUEPROTO.fields_by_name['type_val'].message_type = _TYPEPROTO +_ATTRIBUTEPROTO.fields_by_name['value'].message_type = _VALUEPROTO +_NAMEDVALUEPROTO.fields_by_name['value'].message_type = _VALUEPROTO +_TENSORSHAPEPROTO_DIMENSION.containing_type = _TENSORSHAPEPROTO +_TENSORSHAPEPROTO.fields_by_name['dim'].message_type = _TENSORSHAPEPROTO_DIMENSION +_TYPEPROTO_TENSOR.fields_by_name['elem_type'].enum_type = _DATATYPE +_TYPEPROTO_TENSOR.fields_by_name['shape'].message_type = _TENSORSHAPEPROTO +_TYPEPROTO_TENSOR.containing_type = _TYPEPROTO +_TYPEPROTO_SEQUENCE.fields_by_name['elem_types'].message_type = _TYPEPROTO +_TYPEPROTO_SEQUENCE.containing_type = _TYPEPROTO +_TYPEPROTO.fields_by_name['data_type'].enum_type = _DATATYPE 
+_TYPEPROTO.fields_by_name['tensor_type'].message_type = _TYPEPROTO_TENSOR +_TYPEPROTO.fields_by_name['sequence_type'].message_type = _TYPEPROTO_SEQUENCE +_TYPEPROTO.oneofs_by_name['value'].fields.append( + _TYPEPROTO.fields_by_name['tensor_type']) +_TYPEPROTO.fields_by_name['tensor_type'].containing_oneof = _TYPEPROTO.oneofs_by_name['value'] +_TYPEPROTO.oneofs_by_name['value'].fields.append( + _TYPEPROTO.fields_by_name['sequence_type']) +_TYPEPROTO.fields_by_name['sequence_type'].containing_oneof = _TYPEPROTO.oneofs_by_name['value'] +_PARAMETERPROTO.fields_by_name['type'].message_type = _TYPEPROTO +_PARAMETERPROTO.fields_by_name['default_val'].message_type = _VALUEPROTO +_OUTPUTPROTO.fields_by_name['type'].message_type = _TYPEPROTO +_INPUTPROTO.fields_by_name['type'].enum_type = _INPUTPROTO_EDGETYPE +_INPUTPROTO_EDGETYPE.containing_type = _INPUTPROTO +_NODEPROTO.fields_by_name['input'].message_type = _INPUTPROTO +_NODEPROTO.fields_by_name['attribute'].message_type = _ATTRIBUTEPROTO +_NODEPROTO.fields_by_name['output_type'].message_type = _TYPEPROTO +_MODELPROTO.fields_by_name['graph'].message_type = _GRAPHPROTO +_MODELPROTO.fields_by_name['metadata_operators'].message_type = _OPERATORSETPROTO +_OPERATORSETPROTO.fields_by_name['operators'].message_type = _OPERATORPROTO +_GRAPHPROTO.fields_by_name['node'].message_type = _NODEPROTO +_GRAPHPROTO.fields_by_name['parameters'].message_type = _PARAMETERPROTO +_GRAPHPROTO.fields_by_name['outputs'].message_type = _OUTPUTPROTO +_GRAPHPROTO.fields_by_name['const_vals'].message_type = _NAMEDVALUEPROTO +_TENSORPROTO.fields_by_name['data_type'].enum_type = _DATATYPE +DESCRIPTOR.message_types_by_name['ValueProto'] = _VALUEPROTO +DESCRIPTOR.message_types_by_name['AttributeProto'] = _ATTRIBUTEPROTO +DESCRIPTOR.message_types_by_name['NamedValueProto'] = _NAMEDVALUEPROTO +DESCRIPTOR.message_types_by_name['TensorShapeProto'] = _TENSORSHAPEPROTO +DESCRIPTOR.message_types_by_name['TypeProto'] = _TYPEPROTO +DESCRIPTOR.message_types_by_name['ParameterProto'] = _PARAMETERPROTO +DESCRIPTOR.message_types_by_name['OutputProto'] = _OUTPUTPROTO +DESCRIPTOR.message_types_by_name['InputProto'] = _INPUTPROTO +DESCRIPTOR.message_types_by_name['NodeProto'] = _NODEPROTO +DESCRIPTOR.message_types_by_name['ModelProto'] = _MODELPROTO +DESCRIPTOR.message_types_by_name['OperatorProto'] = _OPERATORPROTO +DESCRIPTOR.message_types_by_name['OperatorSetProto'] = _OPERATORSETPROTO +DESCRIPTOR.message_types_by_name['GraphProto'] = _GRAPHPROTO +DESCRIPTOR.message_types_by_name['TensorProto'] = _TENSORPROTO +DESCRIPTOR.enum_types_by_name['Version'] = _VERSION +DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +ValueProto = _reflection.GeneratedProtocolMessageType('ValueProto', (_message.Message,), { + 'DESCRIPTOR' : _VALUEPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.ValueProto) + }) +_sym_db.RegisterMessage(ValueProto) + +AttributeProto = _reflection.GeneratedProtocolMessageType('AttributeProto', (_message.Message,), { + 'DESCRIPTOR' : _ATTRIBUTEPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.AttributeProto) + }) +_sym_db.RegisterMessage(AttributeProto) + +NamedValueProto = _reflection.GeneratedProtocolMessageType('NamedValueProto', (_message.Message,), { + 'DESCRIPTOR' : _NAMEDVALUEPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # 
@@protoc_insertion_point(class_scope:debugger.NamedValueProto) + }) +_sym_db.RegisterMessage(NamedValueProto) + +TensorShapeProto = _reflection.GeneratedProtocolMessageType('TensorShapeProto', (_message.Message,), { + + 'Dimension' : _reflection.GeneratedProtocolMessageType('Dimension', (_message.Message,), { + 'DESCRIPTOR' : _TENSORSHAPEPROTO_DIMENSION, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.TensorShapeProto.Dimension) + }) + , + 'DESCRIPTOR' : _TENSORSHAPEPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.TensorShapeProto) + }) +_sym_db.RegisterMessage(TensorShapeProto) +_sym_db.RegisterMessage(TensorShapeProto.Dimension) + +TypeProto = _reflection.GeneratedProtocolMessageType('TypeProto', (_message.Message,), { + + 'Tensor' : _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), { + 'DESCRIPTOR' : _TYPEPROTO_TENSOR, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.TypeProto.Tensor) + }) + , + + 'Sequence' : _reflection.GeneratedProtocolMessageType('Sequence', (_message.Message,), { + 'DESCRIPTOR' : _TYPEPROTO_SEQUENCE, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.TypeProto.Sequence) + }) + , + 'DESCRIPTOR' : _TYPEPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.TypeProto) + }) +_sym_db.RegisterMessage(TypeProto) +_sym_db.RegisterMessage(TypeProto.Tensor) +_sym_db.RegisterMessage(TypeProto.Sequence) + +ParameterProto = _reflection.GeneratedProtocolMessageType('ParameterProto', (_message.Message,), { + 'DESCRIPTOR' : _PARAMETERPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.ParameterProto) + }) +_sym_db.RegisterMessage(ParameterProto) + +OutputProto = _reflection.GeneratedProtocolMessageType('OutputProto', (_message.Message,), { + 'DESCRIPTOR' : _OUTPUTPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.OutputProto) + }) +_sym_db.RegisterMessage(OutputProto) + +InputProto = _reflection.GeneratedProtocolMessageType('InputProto', (_message.Message,), { + 'DESCRIPTOR' : _INPUTPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.InputProto) + }) +_sym_db.RegisterMessage(InputProto) + +NodeProto = _reflection.GeneratedProtocolMessageType('NodeProto', (_message.Message,), { + 'DESCRIPTOR' : _NODEPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.NodeProto) + }) +_sym_db.RegisterMessage(NodeProto) + +ModelProto = _reflection.GeneratedProtocolMessageType('ModelProto', (_message.Message,), { + 'DESCRIPTOR' : _MODELPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.ModelProto) + }) +_sym_db.RegisterMessage(ModelProto) + +OperatorProto = _reflection.GeneratedProtocolMessageType('OperatorProto', (_message.Message,), { + 'DESCRIPTOR' : _OPERATORPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.OperatorProto) + }) +_sym_db.RegisterMessage(OperatorProto) + +OperatorSetProto = _reflection.GeneratedProtocolMessageType('OperatorSetProto', (_message.Message,), { + 'DESCRIPTOR' : 
_OPERATORSETPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.OperatorSetProto) + }) +_sym_db.RegisterMessage(OperatorSetProto) + +GraphProto = _reflection.GeneratedProtocolMessageType('GraphProto', (_message.Message,), { + 'DESCRIPTOR' : _GRAPHPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.GraphProto) + }) +_sym_db.RegisterMessage(GraphProto) + +TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), { + 'DESCRIPTOR' : _TENSORPROTO, + '__module__' : 'mindinsight.debugger.proto.ms_graph_pb2' + # @@protoc_insertion_point(class_scope:debugger.TensorProto) + }) +_sym_db.RegisterMessage(TensorProto) + + +# @@protoc_insertion_point(module_scope) diff --git a/mindinsight/debugger/stream_cache/__init__.py b/mindinsight/debugger/stream_cache/__init__.py new file mode 100644 index 00000000..e3077430 --- /dev/null +++ b/mindinsight/debugger/stream_cache/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/mindinsight/debugger/stream_cache/debugger_graph.py b/mindinsight/debugger/stream_cache/debugger_graph.py new file mode 100644 index 00000000..2f13e335 --- /dev/null +++ b/mindinsight/debugger/stream_cache/debugger_graph.py @@ -0,0 +1,289 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""This file is used to define the basic graph."""
+from collections import deque
+
+from mindinsight.datavisual.data_transform.graph.msgraph import MSGraph
+from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
+from mindinsight.debugger.common.exceptions.exceptions import \
+    DebuggerNodeNotInGraphError, DebuggerParamValueError
+from mindinsight.debugger.common.log import logger as log
+from .node import NodeTree
+
+
+class DebuggerGraph(MSGraph):
+    """The `DebuggerGraph` object provides interfaces to describe a debugger graph."""
+
+    def __init__(self):
+        super(DebuggerGraph, self).__init__()
+        self._node_tree = None
+
+    def get_node_name_by_full_name(self, full_name):
+        """Get node name by full name."""
+        inner_name = self._full_name_map_name.get(full_name, '')
+        if not inner_name:
+            log.warning("Cannot find the inner node name for node %s.", full_name)
+
+        return inner_name
+
+    def get_full_name_by_node_name(self, node_name):
+        """Get full name by node name for leaf nodes."""
+        node = self._normal_node_map.get(node_name)
+        if not node:
+            log.warning("Node %s is not a leaf node.", node_name)
+
+        return node.full_name if node else ''
+
+    def get_nodes(self, searched_node_list):
+        """
+        Reorganize the searched leaf nodes into a nested structure by name scope.
+
+        Args:
+            searched_node_list (list[Node]): A list of leaf nodes that
+                matches the given search pattern.
+
+        Returns:
+            list[dict], the searched nodes organized by name scope. For example:
+            [{
+                "name": "Default",
+                "type": "name_scope",
+                "nodes": [{
+                    "name": "Default/Conv2D1",
+                    "type": "name_scope",
+                    "nodes": [{
+                        ...
+                    }]
+                }]
+            },
+            {
+                "name": "Gradients",
+                "type": "name_scope",
+                "nodes": [{
+                    "name": "Gradients/Default",
+                    "type": "name_scope",
+                    "nodes": [{
+                        ...
+                    }]
+                }]
+            }]
+        """
+        # save the nodes in the NodeTree
+        self._node_tree = NodeTree()
+        for node in searched_node_list:
+            self._build_node_tree(node.name, node.type)
+
+        # get the searched nodes in the NodeTree and reorganize them
+        searched_list = []
+        self._traverse_node_tree(self._node_tree, searched_list)
+
+        return searched_list
+
+    def search_nodes_by_pattern(self, pattern):
+        """
+        Search leaf nodes by a given pattern.
+
+        Args:
+            pattern (Union[str, None]): The pattern of the node to search,
+                if None, return all leaf nodes.
+
+        Returns:
+            list[Node], the matched leaf nodes.
+        """
+        if pattern is not None:
+            pattern = pattern.lower()
+            searched_nodes = [
+                node for name, node in self._leaf_nodes.items()
+                if pattern in name.lower()
+            ]
+        else:
+            searched_nodes = list(self._leaf_nodes.values())
+        return searched_nodes
+
+    def _build_node_tree(self, node_name, node_type):
+        """Build the node tree."""
+        scope_names = node_name.split('/')
+        cur_node = self._node_tree
+        for scope_name in scope_names[:-1]:
+            sub_node = cur_node.get(scope_name)
+            if not sub_node:
+                sub_node = cur_node.add(scope_name)
+            cur_node = sub_node
+        cur_node.add(scope_names[-1], node_type)
+
+    def _traverse_node_tree(self, cur_node, search_node_list):
+        """Traverse the node tree and reorganize the searched nodes into a nested structure."""
+        if not cur_node.get_children():
+            return
+        for _, sub_node in cur_node.get_children():
+            sub_nodes = []
+            self._traverse_node_tree(sub_node, sub_nodes)
+            sub_node_dict = {
+                'name': sub_node.node_name,
+                'type': sub_node.node_type,
+                'nodes': sub_nodes
+            }
+            search_node_list.append(sub_node_dict)
+
+    def get_node_type(self, node_name):
+        """
+        Get the type of the node.
+ + Args: + node_name (str): The full name of the node with its scope. + + Returns: + A string, leaf or name_scope. + """ + if node_name and not self.exist_node(name=node_name): + raise DebuggerNodeNotInGraphError(node_name=node_name) + + node = self._leaf_nodes.get(node_name) + if node is not None: + node_type = node.type + else: + node_type = NodeTypeEnum.NAME_SCOPE.value + + return node_type + + def get_tensor_history(self, node_name, depth=0): + """ + Get the tensor history of a specified node. + + Args: + node_name (str): The debug name of the node. + depth (int): The number of layers the user wants to trace. Default is 0. + + Returns: + list, a list of the traced tensors' name and node type, + arranged in order from leaf node to root node. + int, the number of output tensors. + """ + node = self._leaf_nodes.get(node_name) + tensor_history = self._get_tensor_infos_of_node(node) + cur_outputs_nums = len(tensor_history) + cur_depth = 0 + trace_list = deque([(node, cur_depth)]) + while trace_list: + cur_node, cur_depth = trace_list.popleft() + tensors_info = self._get_input_tensors_of_node(cur_node) + if tensors_info: + tensor_history.extend(tensors_info) + if cur_depth < depth: + for name in cur_node.input.keys(): + trace_list.append((self._leaf_nodes[name], cur_depth + 1)) + + return tensor_history, cur_outputs_nums + + @staticmethod + def _get_tensor_infos_of_node(cur_node, slot=None): + """Get tensors info of specified node.""" + tensors_info = [] + if slot is None: + slots = range(cur_node.output_nums) + elif slot >= 0: + slots = [slot] + else: + log.info("Skip get tensor info for %s:%s.", cur_node.name, slot) + return tensors_info + for num in slots: + tensor_info = { + 'name': cur_node.name + ':' + str(num), + 'full_name': cur_node.full_name + ':' + str(num), + 'node_type': cur_node.type + } + tensors_info.append(tensor_info) + + return tensors_info + + def _get_input_tensors_of_node(self, cur_node): + """Get input tensors of node.""" + tensors_info = [] + for name in cur_node.input.keys(): + node = self._leaf_nodes.get(name) + tensor_info = self._get_tensor_infos_of_node(node) + tensors_info.extend(tensor_info) + + return tensors_info + + def get_bfs_order(self): + """ + Traverse the graph in order of breath-first search. + + Returns: + list, including the leaf nodes arranged in BFS order. + """ + root = self.get_default_root() + log.info('Randomly choose node %s as root to do BFS.', root.name) + + bfs_order = [] + self.get_bfs_graph(root.name, bfs_order) + length = len(self._leaf_nodes.keys()) + # Find rest un-traversed nodes + for node_name, _ in self._leaf_nodes.items(): + if node_name not in bfs_order: + self.get_bfs_graph(node_name, bfs_order) + + if len(bfs_order) != length: + log.error("The length of bfs and leaf nodes are not equal.") + msg = "Not all nodes are traversed!" + raise DebuggerParamValueError(msg) + + return bfs_order + + def get_bfs_graph(self, node_name, bfs_order): + """ + Traverse the graph in order of breath-first search. + + Returns: + list, including the leaf nodes arranged in BFS order. + """ + temp_list = deque() + temp_list.append(node_name) + while temp_list: + node_name = temp_list.popleft() + node = self._leaf_nodes.get(node_name) + + if not node: + log.warning('Cannot find node %s in graph. 
Ignored.', node_name)
+                continue
+
+            bfs_order.append(node_name)
+            if node.input:
+                for name in node.input.keys():
+                    if name not in temp_list and name not in bfs_order:
+                        temp_list.append(name)
+            if node.output:
+                for name in node.output.keys():
+                    if name not in temp_list and name not in bfs_order:
+                        temp_list.append(name)
+
+    def get_default_root(self):
+        """
+        Get the default root node for BFS in the graph, using the
+        leaf node with the smallest node id as the default root.
+
+        Returns:
+            Node, the default root node.
+        """
+        default_root = None
+        for _, item in self._leaf_nodes.items():
+            if item.node_id == '1':
+                default_root = item
+                break
+
+        if default_root is None:
+            log.error("Abnormal graph. Invalid node for BFS.")
+            msg = 'Abnormal graph. Invalid node for BFS.'
+            raise DebuggerParamValueError(msg)
+
+        return default_root
diff --git a/mindinsight/debugger/stream_cache/node.py b/mindinsight/debugger/stream_cache/node.py
new file mode 100644
index 00000000..d4d95545
--- /dev/null
+++ b/mindinsight/debugger/stream_cache/node.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+This file is used to define the node of graph and associated base types.
+"""
+from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
+from mindinsight.debugger.common.log import logger as log
+
+
+class NodeTree:
+    """A class for building a node tree."""
+
+    def __init__(self, node_name='', node_type=None):
+        self.node_name = node_name
+        self._node_type = node_type
+        self._children = {}
+
+    @property
+    def node_type(self):
+        """The property of node type."""
+        return self._node_type
+
+    @node_type.setter
+    def node_type(self, value):
+        """Set the node type."""
+        self._node_type = value
+
+    def add(self, name, node_type=None):
+        """Add a sub node."""
+        sub_name = '/'.join([self.node_name, name]) if self.node_name else name
+        sub_node = NodeTree(sub_name, node_type)
+        self._children[name] = sub_node
+        return sub_node
+
+    def get(self, sub_name):
+        """Get a sub node."""
+        return self._children.get(sub_name)
+
+    def get_children(self):
+        """Get all children."""
+        for name_scope, sub_node in self._children.items():
+            yield name_scope, sub_node
+
+    def remove(self, sub_name):
+        """Remove a sub node."""
+        try:
+            self._children.pop(sub_name)
+        except KeyError as err:
+            log.error("Failed to find node %s. %s", sub_name, err)
+            raise DebuggerParamValueError("Failed to find node {}".format(sub_name))
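To illustrate how `NodeTree` assembles scoped names, here is a minimal usage sketch. It assumes the `mindinsight` package from this patch is importable; the node names are made up for illustration:

    from mindinsight.debugger.stream_cache.node import NodeTree

    root = NodeTree()
    default_scope = root.add('Default')               # node_name: 'Default'
    conv_scope = default_scope.add('conv1')           # node_name: 'Default/conv1'
    conv_scope.add('Conv2D-op1', node_type='Conv2D')  # leaf node with a concrete type
    # Walk the tree: each child yields (scope name, subtree).
    for scope_name, sub_node in root.get_children():
        print(scope_name, sub_node.node_name)         # prints: Default Default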
diff --git a/mindinsight/debugger/stream_cache/tensor.py b/mindinsight/debugger/stream_cache/tensor.py
new file mode 100644
index 00000000..98920ae1
--- /dev/null
+++ b/mindinsight/debugger/stream_cache/tensor.py
@@ -0,0 +1,233 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""The definition of tensor stream."""
+from abc import abstractmethod, ABC
+
+import numpy as np
+
+from mindinsight.utils.tensor import TensorUtils
+from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
+from mindinsight.debugger.common.log import logger as log
+from mindinsight.debugger.common.utils import NUMPY_TYPE_MAP
+from mindinsight.debugger.proto.ms_graph_pb2 import DataType
+
+
+class BaseTensor(ABC):
+    """Tensor data structure."""
+
+    def __init__(self, step=0):
+        self._step = step
+
+    @property
+    @abstractmethod
+    def name(self):
+        """The property of tensor name."""
+
+    @property
+    @abstractmethod
+    def dtype(self):
+        """The property of tensor dtype."""
+
+    @property
+    @abstractmethod
+    def shape(self):
+        """The property of tensor shape."""
+
+    @property
+    @abstractmethod
+    def value(self):
+        """The property of tensor value."""
+
+    @abstractmethod
+    def get_tensor_value_by_shape(self, shape=None):
+        """Get tensor value by shape."""
+
+    def _to_dict(self):
+        """Get tensor info in dict format."""
+        res = {
+            'full_name': self.name,
+            'step': self._step,
+            'dtype': self.dtype,
+            'shape': self.shape,
+            'has_prev_step': False
+        }
+        return res
+
+    def get_basic_info(self):
+        """Return the basic info of the tensor."""
+        if not self.shape:
+            value = self.value
+        else:
+            value = 'click to view'
+        res = self._to_dict()
+        res['value'] = value
+        return res
+
+    def get_full_info(self, shape=None):
+        """Get tensor info with value."""
+        res = self._to_dict()
+        value_info = self.get_tensor_serializable_value_by_shape(shape)
+        res.update(value_info)
+        return res
+
+
+class OpTensor(BaseTensor):
+    """Tensor data structure for an operator node."""
+
+    max_number_data_show_on_ui = 100000
+
+    def __init__(self, tensor_proto, step=0):
+        # the type of tensor_proto is TensorProto
+        super(OpTensor, self).__init__(step)
+        self._tensor_proto = tensor_proto
+        self._value = self.generate_value(tensor_proto)
+
+    @property
+    def name(self):
+        """The property of tensor name."""
+        node_name = self._tensor_proto.node_name
+        slot = self._tensor_proto.slot
+        return ':'.join([node_name, slot])
+
+    @property
+    def dtype(self):
+        """The property of tensor dtype."""
+        tensor_type = DataType.Name(self._tensor_proto.data_type)
+
+        return tensor_type
+
+    @property
+    def shape(self):
+        """The property of tensor shape."""
+        return list(self._tensor_proto.dims)
+
+    @property
+    def value(self):
+        """The property of tensor value."""
+        tensor_value = None
+        if self._value is not None:
+            tensor_value = self._value.tolist()
+
+        return tensor_value
+
+    @property
+    def numpy_value(self):
+        """The property of tensor value in numpy type."""
+        return self._value
+
+    def generate_value(self, tensor_proto):
+        """Generate tensor value from proto."""
+        tensor_value = None
+        if tensor_proto.tensor_content:
+            tensor_value = tensor_proto.tensor_content
+            np_type = NUMPY_TYPE_MAP.get(self.dtype)
+            tensor_value = np.frombuffer(tensor_value, dtype=np_type)
+            tensor_value = tensor_value.reshape(self.shape)
+        return tensor_value
+
+    def get_tensor_serializable_value_by_shape(self, shape=None):
+        """
+        Get tensor value info by shape.
+
+        Args:
+            shape (tuple): The specified range of the tensor value.
+
+        Returns:
+            dict, the specified tensor value and value statistics.
+        """
+        tensor_value = self.get_tensor_value_by_shape(shape)
+        res = {}
+        if isinstance(tensor_value, np.ndarray):
+            statistics = TensorUtils.get_statistics_from_tensor(tensor_value)
+            res['statistics'] = TensorUtils.get_statistics_dict(statistics)
+            res['value'] = tensor_value.tolist()
+        return res
+
+    def get_tensor_value_by_shape(self, shape=None):
+        """
+        Get tensor value by shape.
+
+        Args:
+            shape (tuple): The specified shape.
+
+        Returns:
+            Union[None, str, numpy.ndarray], the sub-tensor.
+        """
+        if self._value is None:
+            log.warning("%s has no value yet.", self.name)
+            return None
+        if shape is None or not isinstance(shape, tuple):
+            log.info("Get the whole tensor value since shape is %s.", shape)
+            return self._value
+        if len(shape) != len(self.shape):
+            log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
+            raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
+        try:
+            value = self._value[shape]
+        except IndexError as err:
+            log.error("Invalid shape. Received: %s, tensor shape: %s", shape, self.shape)
+            log.exception(err)
+            raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
+        if isinstance(value, np.ndarray):
+            if value.size > self.max_number_data_show_on_ui:
+                log.info("The tensor size is %s, which is too large to show on UI.", value.size)
+                value = "Too large to show."
+        else:
+            value = np.asarray(value)
+        return value
+
+
+class ConstTensor(BaseTensor):
+    """Tensor data structure for a const node."""
+
+    def __init__(self, const_proto):
+        # the type of const_proto is NamedValueProto
+        super(ConstTensor, self).__init__()
+        self._const_proto = const_proto
+
+    def set_step(self, step):
+        """Set step value."""
+        self._step = step
+
+    @property
+    def name(self):
+        """The property of tensor name."""
+        return self._const_proto.key + ':0'
+
+    @property
+    def dtype(self):
+        """The property of tensor dtype."""
+        return DataType.Name(self._const_proto.value.dtype)
+
+    @property
+    def shape(self):
+        """The property of tensor shape."""
+        return []
+
+    @property
+    def value(self):
+        """The property of tensor value."""
+        fields = self._const_proto.value.ListFields()
+        if len(fields) != 2:
+            log.warning("Unexpected const proto <%s>.\n Please check offline.", self._const_proto)
+        # ListFields() yields (FieldDescriptor, value) pairs, so compare the descriptor name.
+        for field_descriptor, field_value in fields:
+            if field_descriptor.name != 'dtype':
+                return field_value
+        return None
+
+    def get_tensor_value_by_shape(self, shape=None):
+        """Get the const tensor value; the shape argument is ignored."""
+        if shape is not None:
+            log.warning("Invalid shape for const value.")
+        return self.value
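The `shape` argument accepted by `get_tensor_value_by_shape` above is an index tuple applied directly to the cached numpy array, with one entry per dimension. A standalone sketch of that indexing contract (plain numpy, not the class itself; the values are made up):

    import numpy as np

    value = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
    # Select row 0, all columns, and elements 1..2 of the last axis,
    # the same way OpTensor.get_tensor_value_by_shape indexes its cached array.
    view = value[(0, slice(None), slice(1, 3))]
    print(view.shape)  # (3, 2)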
diff --git a/mindinsight/debugger/stream_cache/watchpoint.py b/mindinsight/debugger/stream_cache/watchpoint.py
new file mode 100644
index 00000000..fb267e56
--- /dev/null
+++ b/mindinsight/debugger/stream_cache/watchpoint.py
@@ -0,0 +1,300 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define the watchpoint stream."""
+from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
+from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
+from mindinsight.debugger.common.log import logger as log
+from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD, WatchCondition
+
+WATCHPOINT_CONDITION_MAPPING = {
+    'INF': WatchCondition.Condition.inf,
+    'NAN': WatchCondition.Condition.nan,
+    'OVERFLOW': WatchCondition.Condition.overflow,
+    'MAX_GT': WatchCondition.Condition.max_gt,
+    'MAX_LT': WatchCondition.Condition.max_lt,
+    'MIN_GT': WatchCondition.Condition.min_gt,
+    'MIN_LT': WatchCondition.Condition.min_lt,
+    'MAX_MIN_GT': WatchCondition.Condition.max_min_gt,
+    'MAX_MIN_LT': WatchCondition.Condition.max_min_lt,
+    'MEAN_GT': WatchCondition.Condition.mean_gt,
+    'MEAN_LT': WatchCondition.Condition.mean_lt
+}
+
+
+class WatchNodeTree:
+    """The watch node tree structure."""
+
+    NOT_WATCH = 0  # the scope node and the nodes below are not watched
+    PARTIAL_WATCH = 1  # at least one node under the scope node is not watched
+    TOTAL_WATCH = 2  # the scope node and the nodes below are all watched
+
+    def __init__(self, node_name='', node_type=None, full_name='', watch_status=PARTIAL_WATCH):
+        self._node_name = node_name
+        self._full_name = full_name
+        self._node_type = self._translate_node_type(node_type)
+        self._watch_status = watch_status
+        self._children = {}
+
+    @property
+    def node_name(self):
+        """The property of node name."""
+        return self._node_name
+
+    @property
+    def full_name(self):
+        """The property of full name."""
+        return self._full_name
+
+    @property
+    def node_type(self):
+        """The property of node type."""
+        return self._node_type
+
+    @node_type.setter
+    def node_type(self, value):
+        """Set the node type."""
+        self._node_type = self._translate_node_type(value)
+
+    @property
+    def watch_status(self):
+        """The property of watch status about current node."""
+        return self._watch_status
+
+    def enable_watch_status(self):
+        """Set the watch status of the current node to TOTAL_WATCH."""
+        self._watch_status = WatchNodeTree.TOTAL_WATCH
+
+    @staticmethod
+    def _translate_node_type(node_type):
+        """Translate node type to watch node type."""
+        if not node_type or node_type in \
+                (NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.AGGREGATION_SCOPE.value):
+            return 'scope'
+        return 'leaf'
+
+    def get(self, sub_name):
+        """Get a sub node."""
+        return self._children.get(sub_name)
+
+    def get_children(self):
+        """Get all children."""
+        for name_scope, sub_watch_node in self._children.items():
+            yield name_scope, sub_watch_node
+
+    def add_node(self, node_name, node_type, full_name=''):
+        """
+        Add a watch node to the watch node tree.
+
+        Args:
+            node_name (str): The node name.
+            node_type (str): The node type.
+            full_name (str): The full name of the node.
+        """
+        log.debug("Add node %s with type: %s, full_name: %s", node_name, node_type, full_name)
+        scope_names = node_name.split('/', 1)
+        if len(scope_names) == 1:
+            if not self.get(node_name):
+                self.add(node_name, node_type, full_name, watch_status=WatchNodeTree.TOTAL_WATCH)
+            else:
+                self.get(node_name).enable_watch_status()
+            return
+
+        scope_name, sub_names = scope_names
+        sub_tree = self.get(scope_name)
+        if not sub_tree:
+            sub_tree = self.add(scope_name, watch_status=WatchNodeTree.PARTIAL_WATCH)
+        sub_tree.add_node(sub_names, node_type, full_name)
+
+    def add(self, name, node_type=None, full_name='', watch_status=PARTIAL_WATCH):
+        """Add a sub WatchNodeTree."""
+        sub_name = '/'.join([self._node_name, name]) if self._node_name else name
+        sub_tree = WatchNodeTree(sub_name, node_type, full_name, watch_status)
+        self._children[name] = sub_tree
+
+        return sub_tree
+
+    def remove_node(self, node_name):
+        """Remove a sub node from the current tree."""
+        log.debug("Remove %s", node_name)
+        scope_names = node_name.split('/', 1)
+        sub_tree_name = scope_names[0]
+        sub_tree = self._children.get(sub_tree_name)
+        if not sub_tree:
+            log.error("Failed to find node %s in WatchNodeTree.", sub_tree_name)
+            raise DebuggerParamValueError("Failed to find node {}".format(sub_tree_name))
+
+        if len(scope_names) > 1:
+            sub_tree.remove_node(scope_names[1])
+
+        if sub_tree.watch_status == WatchNodeTree.NOT_WATCH or len(scope_names) == 1:
+            self._children.pop(sub_tree_name)
+
+        self._watch_status = WatchNodeTree.PARTIAL_WATCH if self._children else \
+            WatchNodeTree.NOT_WATCH
+
+
+class Watchpoint:
+    """
+    The class of watchpoint stream.
+
+    Args:
+        watchpoint_id (int): The id of the watchpoint.
+        watch_condition (dict): The condition of the watchpoint.
+
+            - condition (str): Accepts the condition names in
+                `WATCHPOINT_CONDITION_MAPPING`, such as `INF` or `NAN`.
+
+            - param (list[float]): The parameters of the condition.
+                At most one parameter is currently used.
+    """
+
+    def __init__(self, watchpoint_id, watch_condition):
+        self._id = watchpoint_id
+        self._condition = watch_condition
+        self._watch_node = WatchNodeTree()
+
+    @property
+    def watchpoint_id(self):
+        """The property of watchpoint id."""
+        return self._id
+
+    @property
+    def nodes(self):
+        """The property of watch nodes."""
+        return self._watch_node
+
+    @property
+    def condition(self):
+        """The property of watch condition."""
+        return self._condition
+
+    def copy_nodes_from(self, other_watchpoint):
+        """
+        Copy nodes from another watchpoint.
+
+        Args:
+            other_watchpoint (Watchpoint): The watchpoint to copy nodes from.
+        """
+        self._watch_node = other_watchpoint.nodes
+
+    def add_nodes(self, nodes):
+        """Add nodes to the watchpoint."""
+        if not nodes:
+            log.warning("Add empty nodes.")
+            return
+
+        if not isinstance(nodes, list):
+            nodes = [nodes]
+        for node in nodes:
+            self._watch_node.add_node(node.name, node.type, node.full_name)
+
+    def remove_nodes(self, nodes):
+        """Remove nodes from the watchpoint."""
+        if not nodes:
+            return
+        if not isinstance(nodes, list):
+            nodes = [nodes]
+        for node in nodes:
+            node_name = node.split(':')[0]
+            self._watch_node.remove_node(node_name)
+
+    def get_node_status(self, node_name, node_type, full_name):
+        """Check whether the node is watched, and return its watch status."""
+        scope_names = node_name.split('/')
+        cur_node = self._watch_node
+        status = WatchNodeTree.PARTIAL_WATCH
+        for scope_name in scope_names:
+            cur_node = cur_node.get(scope_name)
+            if cur_node is None:
+                status = WatchNodeTree.NOT_WATCH
+                break
+            if cur_node.watch_status == WatchNodeTree.TOTAL_WATCH:
+                status = WatchNodeTree.TOTAL_WATCH
+                break
+        if status == WatchNodeTree.TOTAL_WATCH and cur_node.node_name != node_name:
+            self._watch_node.add_node(node_name, node_type, full_name)
+
+        return status
+
+    def get_watch_node(self, cur_watch_node, watch_node_list):
+        """
+        Traverse the watch node tree and append fully watched nodes to `watch_node_list`.
+
+        Args:
+            cur_watch_node (WatchNodeTree): The current watch node.
+            watch_node_list (list[WatchNodeTree]): The list of fully watched nodes.
+        """
+        if cur_watch_node.watch_status == WatchNodeTree.TOTAL_WATCH:
+            watch_node_list.append(cur_watch_node)
+            return
+        for _, watch_node in cur_watch_node.get_children():
+            self.get_watch_node(watch_node, watch_node_list)
+
+    def get_set_cmd(self):
+        """Return the watchpoint in proto format."""
+        # get watch nodes.
+        watch_nodes = []
+        self.get_watch_node(self._watch_node, watch_nodes)
+        # construct SetCMD
+        set_cmd = SetCMD()
+        set_cmd.id = self._id
+        set_cmd.delete = False
+        set_cmd.watch_condition.condition = WATCHPOINT_CONDITION_MAPPING.get(
+            self._condition.get('condition'))
+        if self._condition.get('param'):
+            # at most one param is provided
+            set_cmd.watch_condition.value = self._condition.get('param')
+        for watch_node in watch_nodes:
+            event_node = set_cmd.watch_nodes.add()
+            event_node.node_name = watch_node.full_name
+            event_node.node_type = watch_node.node_type
+
+        return set_cmd
+
+    def get_watch_condition_info(self):
+        """Get watch condition info."""
+        watchpoint_info = {
+            'id': self._id,
+            'watch_condition': self._condition
+        }
+        return watchpoint_info
+
+
+class WatchpointHit:
+    """The watchpoint hit structure."""
+
+    def __init__(self, tensor_proto, watchpoint, node_name):
+        self._node_name = node_name
+        self._full_name = tensor_proto.node_name
+        self._slot = tensor_proto.slot
+        self._watchpoint = watchpoint
+
+    @property
+    def tensor_full_name(self):
+        """The property of tensor full name."""
+        tensor_name = ':'.join([self._full_name, self._slot])
+        return tensor_name
+
+    @property
+    def tensor_name(self):
+        """The property of tensor name."""
+        return ':'.join([self._node_name, self._slot])
+
+    @property
+    def watchpoint(self):
+        """The property of watchpoint."""
+        watchpoint = self._watchpoint.get_watch_condition_info()
+        return watchpoint
+
+    def __eq__(self, other):
+        """Define the equal condition."""
+        flag = self.tensor_full_name == other.tensor_full_name and self.watchpoint == other.watchpoint
+        return flag
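A rough sketch of how a `Watchpoint` is assembled and serialized (assuming the package is importable; the namedtuple below is a hypothetical stand-in for real graph nodes, which also expose `name`, `type` and `full_name`):

    from collections import namedtuple
    from mindinsight.debugger.stream_cache.watchpoint import Watchpoint

    Node = namedtuple('Node', ['name', 'type', 'full_name'])  # stand-in for a graph node

    watchpoint = Watchpoint(watchpoint_id=1, watch_condition={'condition': 'INF'})
    watchpoint.add_nodes(Node('Default/conv1/Conv2D-op1', 'Conv2D',
                              'Default/conv1/Conv2D-op1'))
    set_cmd = watchpoint.get_set_cmd()  # SetCMD proto sent to the training backend
    print(set_cmd.id, len(set_cmd.watch_nodes))  # 1 1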
diff --git a/mindinsight/debugger/stream_handler/__init__.py b/mindinsight/debugger/stream_handler/__init__.py
new file mode 100644
index 00000000..73bc29ab
--- /dev/null
+++ b/mindinsight/debugger/stream_handler/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Import the stream handlers."""
+from .event_handler import EventHandler
+from .metadata_handler import MetadataHandler
+from .graph_handler import GraphHandler
+from .tensor_handler import TensorHandler
+from .watchpoint_handler import WatchpointHandler, WatchpointHitHandler
+
+__all__ = ['EventHandler', 'MetadataHandler', 'GraphHandler', 'TensorHandler',
+           'WatchpointHandler', 'WatchpointHitHandler']
diff --git a/mindinsight/debugger/stream_handler/base_handler.py b/mindinsight/debugger/stream_handler/base_handler.py
new file mode 100644
index 00000000..d57c4442
--- /dev/null
+++ b/mindinsight/debugger/stream_handler/base_handler.py
@@ -0,0 +1,34 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define the stream handler base."""
+from abc import abstractmethod
+
+
+class StreamHandlerBase:
+    """The stream handler base."""
+
+    @abstractmethod
+    def put(self, value):
+        """Abstract method to set data."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get(self, filter_condition):
+        """Abstract method to get data."""
+        raise NotImplementedError
+
+    def clean(self):
+        """Clean cache."""
+        self.__init__()
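Every handler that follows implements this two-method contract. A minimal conforming handler might look like the following (a hypothetical toy example, not part of the patch):

    from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase

    class ListHandler(StreamHandlerBase):
        """Toy handler that caches values in arrival order."""

        def __init__(self):
            self._values = []

        def put(self, value):
            self._values.append(value)

        def get(self, filter_condition=None):
            # Return everything when no filter is given.
            if filter_condition is None:
                return list(self._values)
            return [v for v in self._values if filter_condition(v)]

Note that `clean()` re-runs `__init__()`, so any handler that keeps all of its state in instance attributes gets cache reset for free.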
diff --git a/mindinsight/debugger/stream_handler/event_handler.py b/mindinsight/debugger/stream_handler/event_handler.py
new file mode 100644
index 00000000..092453bd
--- /dev/null
+++ b/mindinsight/debugger/stream_handler/event_handler.py
@@ -0,0 +1,159 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define the message handler."""
+import uuid
+from queue import Queue, Empty
+from threading import Lock
+
+from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
+from mindinsight.debugger.common.log import logger as log
+from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
+
+
+class EventHandler(StreamHandlerBase):
+    """Event message handler."""
+
+    max_limit = 1000  # the max number of items in cache
+
+    def __init__(self):
+        self._prev_flag = str(uuid.uuid4())
+        self._cur_flag = str(uuid.uuid4())
+        self._next_idx = 0
+        self._event_cache = [None] * self.max_limit
+        self._pending_requests = {}
+        self._lock = Lock()
+
+    @property
+    def next_pos(self):
+        """The next pos to be updated in cache."""
+        return ':'.join([self._cur_flag, str(self._next_idx)])
+
+    def has_pos(self, pos):
+        """Return the event at `pos` if the position is valid, else None."""
+        cur_flag, cur_idx = self._parse_pos(pos)
+        event = self._event_cache[cur_idx]
+        if event is not None:
+            if not cur_flag or (cur_flag == self._cur_flag and cur_idx < self._next_idx) or \
+                    (cur_flag == self._prev_flag and cur_idx >= self._next_idx):
+                return event
+
+        return None
+
+    def clean(self):
+        """Clean event cache."""
+        with self._lock:
+            self._prev_flag = str(uuid.uuid4())
+            self._cur_flag = str(uuid.uuid4())
+            self._next_idx = 0
+            self._event_cache = [None] * self.max_limit
+            value = {'metadata': {'pos': '0'}}
+            self.clean_pending_requests(value)
+            log.debug("Clean event cache.")
+
+    def put(self, value):
+        """
+        Put value into event_cache.
+
+        Args:
+            value (dict): The event to be put into cache.
+        """
+        if not isinstance(value, dict):
+            log.error("Dict type is required when putting an event message.")
+            raise DebuggerParamValueError("Dict type is required when putting an event message.")
+
+        with self._lock:
+            log.debug("Put the %d-th message into queue. \n %d requests are waiting.",
+                      self._next_idx, len(self._pending_requests))
+            cur_pos = self._next_idx
+            # update next pos
+            self._next_idx += 1
+            if self._next_idx >= self.max_limit:
+                self._next_idx = 0
+                self._prev_flag = self._cur_flag
+                self._cur_flag = str(uuid.uuid4())
+            # set next pos
+            if not value.get('metadata'):
+                value['metadata'] = {}
+            value['metadata']['pos'] = self.next_pos
+            self._event_cache[cur_pos] = value
+            # feed the value to pending requests
+            self.clean_pending_requests(value)
+
+    def clean_pending_requests(self, value):
+        """Clean pending requests."""
+        for _, request in self._pending_requests.items():
+            request.put(value)
+        self._pending_requests = {}
+
+    def get(self, filter_condition=None):
+        """
+        Get the pos-th value from event_cache according to filter_condition.
+
+        Args:
+            filter_condition (str): The index of the event in cache. Default: None.
+
+        Returns:
+            object, the pos-th event.
+        """
+        flag, idx = self._parse_pos(filter_condition)
+        cur_id = str(uuid.uuid4())
+        with self._lock:
+            # reset the pos after the cache is re-initialized.
+            if not flag or flag not in [self._cur_flag, self._prev_flag]:
+                idx = 0
+            # get the event from cache immediately
+            if idx != self._next_idx and self._event_cache[idx]:
+                return self._event_cache[idx]
+            # wait for the event
+            cur_queue = Queue(maxsize=1)
+            self._pending_requests[cur_id] = cur_queue
+        # block until the event has been received
+        event = self._wait_for_event(cur_id, cur_queue, filter_condition)
+
+        return event
+
+    def _parse_pos(self, pos):
+        """Parse the flag and index out of a pos string."""
+        elements = pos.split(':')
+        try:
+            idx = int(elements[-1])
+        except ValueError:
+            log.error("Invalid index. The index in pos should be a digit, but got pos: %s", pos)
+            raise DebuggerParamValueError("Invalid pos.")
+
+        if idx < 0 or idx >= self.max_limit:
+            log.error("Invalid index. The index in pos should be within [0, %d)", self.max_limit)
+            raise DebuggerParamValueError(f"Invalid pos. {idx}")
+        flag = elements[0] if len(elements) == 2 else ''
+
+        return flag, idx
+
+    def _wait_for_event(self, cur_id, cur_queue, pos):
+        """Wait for the pos-th event."""
+        try:
+            # set the timeout to 25 seconds, which is less than the timeout limit from the UI
+            event = cur_queue.get(timeout=25)
+        except Empty:
+            event = None
+
+        if event is None:
+            with self._lock:
+                if self._pending_requests.get(cur_id):
+                    self._pending_requests.pop(cur_id)
+                    log.debug("Clean timeout request. Left pending requests: %d",
+                              len(self._pending_requests))
+            event = {'metadata': {'pos': pos}}
+
+        return event
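Clients consume this event cache with a long-polling loop: each response carries the next `pos` (a `flag:index` pair), which the client echoes back on its next request. A rough sketch of the consumer side (hypothetical usage, assuming the package is importable):

    from mindinsight.debugger.stream_handler.event_handler import EventHandler

    handler = EventHandler()
    handler.put({'watchpoint_hit': 'sample payload'})

    pos = '0'                        # initial position, as used by clean()
    event = handler.get(pos)         # a cached event is returned immediately
    pos = event['metadata']['pos']   # echo back on the next poll; a poll with
                                     # no pending event blocks for up to 25 s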
diff --git a/mindinsight/debugger/stream_handler/graph_handler.py b/mindinsight/debugger/stream_handler/graph_handler.py
new file mode 100644
index 00000000..adb31d25
--- /dev/null
+++ b/mindinsight/debugger/stream_handler/graph_handler.py
@@ -0,0 +1,314 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define the graph stream handler."""
+from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
+    DebuggerNodeNotInGraphError, DebuggerGraphNotExistError
+from mindinsight.debugger.common.log import logger as log
+from mindinsight.debugger.stream_cache.debugger_graph import DebuggerGraph
+from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
+
+
+class GraphHandler(StreamHandlerBase):
+    """Graph handler."""
+
+    def __init__(self):
+        self._graph_proto = None
+        self._graph = None
+        self._searched_node_list = []
+        self.bfs_order = []
+
+    @property
+    def graph(self):
+        """The property of graph."""
+        return self._graph_proto
+
+    def put(self, value):
+        """
+        Put value into graph cache. Called by grpc server.
+
+        Args:
+            value (GraphProto): The Graph proto message.
+        """
+        self._graph_proto = value
+        log.info("Put graph into cache.")
+
+        # build graph
+        graph = DebuggerGraph()
+        graph.build_graph(value)
+        self._graph = graph
+        self.bfs_order = self._graph.get_bfs_order()
+
+    def get(self, filter_condition=None):
+        """
+        Get the graph of the specified node.
+
+        Args:
+            filter_condition (dict):
+
+                - name (str): The full debug node name.
+
+                - single_node (bool): If True, return the graph from root
+                    to the specific node; else, return the sublayer of the
+                    graph. Default: False.
+
+        Returns:
+            dict, the graph data.
+        """
+        try:
+            self._graph_exists()
+        except DebuggerGraphNotExistError:
+            log.warning('The graph is empty. To view a graph, '
+                        'please start the training script first.')
+            return {'graph': {}}
+
+        if filter_condition is None:
+            filter_condition = {}
+        single_node = filter_condition.get('single_node', False)
+        name = filter_condition.get('name')
+
+        graph = {}
+        if single_node:
+            nodes = self.get_single_node(name)
+        else:
+            nodes = self.list_nodes(name)
+        graph.update(nodes)
+
+        return {'graph': graph}
+
+    def get_tensor_history(self, node_name, depth=0):
+        """
+        Get the tensor history of a specified node.
+
+        Args:
+            node_name (str): The debug name of the node.
+            depth (int): The number of layers the user
+                wants to trace. Default: 0.
+
+        Returns:
+            dict, basic tensor history, including only the tensor name,
+            tensor type and node type.
+        """
+        self._graph_exists()
+        if not self._graph.exist_node(node_name):
+            raise DebuggerNodeNotInGraphError(node_name)
+
+        tensor_history, cur_outputs_nums = self._graph.get_tensor_history(
+            node_name, depth
+        )
+        # add the tensor type for tensor history
+        self._update_tensor_history(tensor_history[0:cur_outputs_nums], 'output')
+        self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input')
+        log.debug("Get %d tensors in tensor history for node <%s>.", len(tensor_history), node_name)
+        return {'tensor_history': tensor_history}
+
+    @staticmethod
+    def _update_tensor_history(tensor_history, tensor_type):
+        """
+        Add the tensor source type to the tensor history.
+
+        Args:
+            tensor_history (list[dict]): Tensor history from the Graph stream. Each element has two
+                keys: `node_type` and `name`. `node_type` refers to the type of the node which
+                the tensor comes from. `name` refers to the tensor name.
+            tensor_type (str): The source type of the tensor. `input` or `output`.
+        """
+        for single_tensor_info in tensor_history:
+            single_tensor_info['type'] = tensor_type
+
+    def search_nodes(self, pattern):
+        """
+        Search nodes by a given pattern.
+
+        Args:
+            pattern (Union[str, None]): The pattern of the node to search,
+                if None, return all node names.
+
+        Returns:
+            dict, the searched nodes organized by name scope.
+        """
+        self._graph_exists()
+        self._searched_node_list = self._graph.search_nodes_by_pattern(pattern)
+        nodes = self._graph.get_nodes(self._searched_node_list)
+
+        return {'nodes': nodes}
+
+    def get_node_names(self, pattern=None):
+        """Get graph nodes according to pattern."""
+        return self._graph.search_nodes_by_pattern(pattern)
+
+    def get_searched_node_list(self):
+        """Get searched node list."""
+        return self._searched_node_list
+
+    def get_node_type(self, node_name):
+        """
+        Get the type of the specified node.
+
+        Args:
+            node_name (str): The debug name of the node.
+
+        Returns:
+            str, the node type (name_scope or leaf).
+ """ + self._graph_exists() + node_type = self._graph.get_node_type(node_name) + + return node_type + + def get_full_name(self, node_name): + """Get full name according to ui node name.""" + full_name = self._graph.get_full_name_by_node_name(node_name) if node_name else '' + return full_name + + def get_node_name_by_full_name(self, full_name): + """Get UI node name by full name.""" + if self._graph: + node_name = self._graph.get_node_name_by_full_name(full_name) + else: + node_name = '' + log.info("No graph received yet.") + return node_name + + def list_nodes(self, scope): + """ + Get the nodes of every layer in graph. + + Args: + scope (str): The name of a scope. + + Returns: + TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': []}. + example: + { + "nodes" : [ + { + "attr" : + { + "index" : "i: 0\n" + }, + "input" : {}, + "name" : "input_tensor", + "output" : + { + "Default/TensorAdd-op17" : + { + "edge_type" : "data", + "scope" : "name_scope", + "shape" : [1, 16, 128, 128] + } + }, + "output_i" : -1, + "proxy_input" : {}, + "proxy_output" : {}, + "independent_layout" : False, + "subnode_count" : 0, + "type" : "Data" + } + ] + } + """ + if scope and not self._graph.exist_node(scope): + raise DebuggerNodeNotInGraphError(node_name=scope) + + nodes = self._graph.list_node_by_scope(scope=scope) + return {'nodes': nodes} + + def get_node_by_bfs_order(self, node_name=None, ascend=True): + """ + Traverse the graph in order of breath-first search by given node. + + Args: + node_name (str): The node name which will be regarded + as the start node in graph. + ascend (bool): If True, traverse the input nodes; + If False, traverse the output nodes. Default is True. + + Returns: + dict, including the searched node and its tensor value. + """ + self._graph_exists() + bfs_order = self.bfs_order + length = len(bfs_order) + + if not bfs_order: + log.error('Cannot get the BFS order of the graph!') + msg = 'Cannot get the BFS order of the graph!' + raise DebuggerParamValueError(msg) + + if node_name is None: + if ascend is False: + next_node = bfs_order[-1] + else: + next_node = bfs_order[0] + else: + try: + index = bfs_order.index(node_name) + log.debug("The index of the node in BFS list is: %d", index) + except ValueError as err: + log.error('Cannot find the node: %s. Please check ' + 'the node name: %s', node_name, err) + msg = f'Cannot find the node: {node_name}. ' \ + f'Please check the node name {err}.' + raise DebuggerParamValueError(msg) + + next_node = self.get_next_node_in_bfs(index, length, ascend) + + return next_node + + def get_next_node_in_bfs(self, index, length, ascend): + """ + Get the next node in bfs order. + + Args: + index (int): The current index. + length (int): The number of all leaf nodes. + ascend (bool): Whether get the node in ascend order or not. + + Returns: + Union[None, dict], the next node object in dict type or None. + """ + next_node = None + if 0 <= index < length: + if ascend is True and index < length - 1: + next_node = self.bfs_order[index + 1] + elif ascend is False and index > 0: + next_node = self.bfs_order[index - 1] + + return next_node + + def get_single_node(self, name): + """ + Search node, and return every layer nodes until this node. + + Args: + name (str): The name of node. + + Returns: + dict, every layer nodes until this node. + """ + nodes = self._graph.search_single_node(name) + + return nodes + + def _graph_exists(self): + """ + Check if the graph has been loaded in the debugger cache. 
+ + Raises: + DebuggerGraphNotExistError: If the graph does not exist. + """ + if self._graph is None: + log.error('The graph does not exist. Please start the ' + 'training script and try again.') + raise DebuggerGraphNotExistError diff --git a/mindinsight/debugger/stream_handler/metadata_handler.py b/mindinsight/debugger/stream_handler/metadata_handler.py new file mode 100644 index 00000000..8eae8b0e --- /dev/null +++ b/mindinsight/debugger/stream_handler/metadata_handler.py @@ -0,0 +1,131 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define the metadata stream handler.""" +from mindinsight.debugger.common.log import logger as log +from mindinsight.debugger.common.utils import ServerStatus +from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase + + +class MetadataHandler(StreamHandlerBase): + """Metadata Handler.""" + + def __init__(self): + self._state = ServerStatus.PENDING + self._device_name = "" + self._step = 0 + self._client_ip = "" + self._cur_node_name = "" + self._cur_full_name = "" + self._backend = "" + + @property + def device_name(self): + """The property of device name.""" + return self._device_name + + @property + def step(self): + """The property of current step.""" + return self._step + + @property + def node_name(self): + """The property of current node name.""" + return self._cur_node_name + + @node_name.setter + def node_name(self, node_name): + """The property of current node name.""" + self._cur_node_name = node_name + + @property + def full_name(self): + """The property of current node name.""" + return self._cur_full_name + + @property + def backend(self): + """The property of current backend.""" + return self._backend + + @property + def state(self): + """The property of state.""" + return self._state.value + + @state.setter + def state(self, value): + """ + Set the property of state. + + Args: + value (str): The new state. + """ + self._state = ServerStatus(value) + + @property + def client_ip(self): + """The property of client ip.""" + return self._client_ip + + @client_ip.setter + def client_ip(self, value): + """ + Set the property of client ip. + + Args: + value (str): The new ip. + """ + self._client_ip = str(value) + + def put(self, value): + """ + Put value into metadata cache. Called by grpc server. + + Args: + value (MetadataProto): The Metadata proto message. + """ + self._device_name = value.device_name.split(':')[0] + self._step = value.cur_step + self._cur_full_name = value.cur_node + self._backend = value.backend if value.backend else "Ascend" + log.debug("Put metadata into cache at the %d-th step.", self._step) + + def get(self, filter_condition=None): + """ + Get updated value. Called by main server. + + Args: + filter_condition (str): The filter property. + + Returns: + dict, the metadata. 
+        """
+        metadata = {}
+        if filter_condition is None:
+            metadata = {
+                'state': self.state,
+                'step': self.step,
+                'device_name': self.device_name,
+                'pos': '0',
+                'ip': self.client_ip,
+                'node_name': self.node_name,
+                'backend': self.backend
+            }
+        else:
+            metadata[filter_condition] = getattr(self, filter_condition) if \
+                hasattr(self, filter_condition) else ''
+
+        return {'metadata': metadata}
diff --git a/mindinsight/debugger/stream_handler/tensor_handler.py b/mindinsight/debugger/stream_handler/tensor_handler.py
new file mode 100644
index 00000000..3d1dee02
--- /dev/null
+++ b/mindinsight/debugger/stream_handler/tensor_handler.py
@@ -0,0 +1,298 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define the tensor stream handler."""
+import numpy as np
+
+from mindinsight.datavisual.data_transform.graph.node import NodeTypeEnum
+from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
+from mindinsight.debugger.common.log import logger as log
+from mindinsight.debugger.proto.ms_graph_pb2 import DataType
+from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor
+from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
+from mindinsight.utils.tensor import TensorUtils
+
+
+class TensorHandler(StreamHandlerBase):
+    """Tensor handler."""
+
+    def __init__(self):
+        self._const_vals = {}
+        self._tensors = {}
+        self._cur_step = 0
+
+    def put(self, value):
+        """
+        Put value into tensor cache. Called by grpc server.
+
+        Args:
+            value (dict): The tensor message.
+
+                - step (int): The current step of the tensor.
+
+                - tensor_protos (list[TensorProto]): The tensor protos.
+        """
+        tensor_protos = value.get('tensor_protos')
+        merged_tensor = self._get_merged_tensor(tensor_protos)
+        step = value.get('step', 0)
+        if merged_tensor.iter and step > 0:
+            log.debug("Received previous tensor.")
+            step -= 1
+        tensor = OpTensor(merged_tensor, step)
+        self._put_tensor_into_cache(tensor, step)
+        log.debug("Put tensor %s of step %d into cache.", tensor.name, step)
+
+    @staticmethod
+    def _get_merged_tensor(tensor_protos):
+        """
+        Merge a list of parsed tensor protos into one.
+
+        Args:
+            tensor_protos (list[TensorProto]): List of tensor protos.
+
+        Returns:
+            TensorProto, the merged tensor proto.
+        """
+        merged_tensor = tensor_protos[-1]
+        if len(tensor_protos) > 1:
+            tensor_value = bytes()
+            for tensor_proto in tensor_protos:
+                if not tensor_proto.tensor_content:
+                    log.warning("Cannot find tensor value for %s:%s.",
+                                tensor_proto.node_name, tensor_proto.slot)
+                    break
+                tensor_value += tensor_proto.tensor_content
+            merged_tensor.tensor_content = tensor_value
+            log.debug("Merged multiple tensor values into one.")
+        return merged_tensor
+
+    def _put_tensor_into_cache(self, tensor, step):
+        """
+        Put tensor into cache.
+
+        Args:
+            tensor (OpTensor): The tensor value.
+            step (int): The step of the tensor.
+        """
+        cache_tensor = self._tensors.get(tensor.name)
+        if cache_tensor is None:
+            cache_tensor = {}
+            self._tensors[tensor.name] = cache_tensor
+        cache_tensor[step] = tensor
+
+    def put_const_vals(self, const_vals):
+        """
+        Put const values into tensor cache.
+
+        Args:
+            const_vals (list[NamedValueProto]): List of const values.
+        """
+        for const_val in const_vals:
+            if not (const_val.value and const_val.key):
+                continue
+            if DataType.Name(const_val.value.dtype) == "DT_TENSOR":
+                tensor_proto = const_val.value.tensor_val
+                tensor_proto.node_name = const_val.key
+                tensor_proto.slot = '0'
+                const_tensor = OpTensor(tensor_proto)
+            else:
+                const_tensor = ConstTensor(const_val)
+            self._const_vals[const_tensor.name] = const_tensor
+
+    def get(self, filter_condition=None):
+        """
+        Get full tensor value.
+
+        Args:
+            filter_condition (dict): Filter condition.
+
+                - name (str): The name of the tensor.
+
+                - node_type (str): The type of the node.
+
+                - shape (tuple): The specified shape of the tensor value.
+
+        Returns:
+            dict, the tensor_value.
+        """
+        name = filter_condition.get('name')
+        node_type = filter_condition.get('node_type')
+        shape = filter_condition.get('shape')
+        tensor = self._get_tensor(name, node_type)
+        if not tensor:
+            log.error("No tensor named %s", name)
+            raise DebuggerParamValueError("No tensor named {}".format(name))
+        tensor_info = tensor.get_full_info(shape)
+        self._update_has_prev_step_field(tensor_info, name, node_type)
+        return {'tensor_value': tensor_info}
+
+    def _get_tensor(self, tensor_name, node_type=None, step=None):
+        """
+        Get tensor according to tensor name and node type.
+
+        Args:
+            tensor_name (str): Tensor name, format like `node_name:slot`.
+            node_type (str): Node type.
+            step (int): The step of the tensor info. Default: None.
+
+        Returns:
+            Union[OpTensor, ConstTensor], the tensor object.
+        """
+        if step is None:
+            step = self._cur_step
+        tensor = self._tensors.get(tensor_name, {}).get(step)
+        if not tensor and node_type == NodeTypeEnum.CONST.value:
+            const_name = tensor_name.rsplit('/', 1)[-1]
+            tensor = self._const_vals.get(const_name)
+            self._tensors[tensor_name] = {step: tensor}
+
+        return tensor
+
+    def _get_basic_info(self, tensor_name, node_type=None):
+        """Get the latest basic tensor info by tensor name."""
+        tensor = self._get_tensor(tensor_name, node_type)
+        if tensor:
+            return tensor.get_basic_info()
+
+        return None
+
+    def get_tensor_history(self, tensor_names):
+        """Get tensor history for tensor names."""
+        # only used by grpc server, could be removed later
+        tensor_infos = []
+        for tensor_name in tensor_names:
+            tensor_info = self._get_basic_info(tensor_name)
+            tensor_infos.append(tensor_info)
+
+        return {'tensor_history': tensor_infos}
+
+    def update_tensor_history(self, tensor_history):
+        """
+        Add tensor basic info to tensor_history.
+
+        Args:
+            tensor_history (dict): Tensor history, including a list of tensor name and type.
+
+        Returns:
+            list[dict], the list of tensor basic info cache.
+        """
+        missed_tensors = []
+        for tensor_info in tensor_history.get('tensor_history'):
+            tensor_name = tensor_info.get('full_name')
+            node_type = tensor_info.get('node_type')
+            basic_info = self._get_basic_info(tensor_name, node_type)
+            flag = self._update_has_prev_step_field(basic_info, tensor_name, node_type)
+            if flag is False:
+                missed_tensor = tensor_info.copy()
+                missed_tensor['iter'] = 'prev'
+                missed_tensors.append(missed_tensor)
+                log.debug("Add previous view cmd for %s", tensor_name)
+            # add `has_prev_step` field to tensor basic info.
+            if basic_info:
+                tensor_info.update(basic_info)
+            else:
+                missed_tensors.append(tensor_info)
+                log.debug("Add view cmd for %s.", tensor_name)
+
+        return missed_tensors
+
+    def _update_has_prev_step_field(self, tensor_info, tensor_name, node_type):
+        """Update the `has_prev_step` field in tensor info."""
+        flag = None
+        if node_type == NodeTypeEnum.PARAMETER.value:
+            flag = self._has_prev_tensor_value(tensor_name)
+        if flag and tensor_info:
+            tensor_info['has_prev_step'] = True
+        return flag
+
+    def _has_prev_tensor_value(self, tensor_name):
+        """
+        Check if the tensor has a valid value for the previous step.
+
+        Args:
+            tensor_name (str): Tensor name.
+
+        Returns:
+            bool, whether the tensor has a valid previous value.
+        """
+        flag = None
+        # check if the tensor has a previous step value.
+        prev_step = self._cur_step - 1
+        if prev_step < 0:
+            return flag
+        tensor = self._get_tensor(tensor_name, step=prev_step)
+        flag = bool(tensor and tensor.value)
+        return flag
+
+    def get_tensor_value_by_name(self, tensor_name, prev=False):
+        """Get the tensor object by name for the current or previous step."""
+        cur_step = self._cur_step
+        step = cur_step - 1 if prev else cur_step
+        if step < 0:
+            log.warning("Step %d has no previous value for tensor %s.", cur_step, tensor_name)
+            return None
+        tensor = self._get_tensor(tensor_name, step=step)
+
+        return tensor
+
+    def clean_tensors(self, cur_step):
+        """Clean expired tensors from the cache."""
+        self._cur_step = cur_step
+        expired_tensor = []
+        for tensor_name, tensor in self._tensors.items():
+            expired_step = [step for step in tensor.keys() if step <= cur_step - 2]
+            for step in expired_step:
+                tensor.pop(step)
+            if not tensor:
+                expired_tensor.append(tensor_name)
+        for tensor_name in expired_tensor:
+            self._tensors.pop(tensor_name)
+
+    def get_tensors_diff(self, tensor_name, shape, tolerance=0):
+        """
+        Get the comparison between the current and previous step values of a tensor.
+
+        Args:
+            tensor_name (str): The name of the tensor in the cache.
+            shape (tuple): Specify concrete dimensions of the tensor value.
+            tolerance (float): Specify the tolerance of the difference between the current
+                step tensor and the previous step tensor. Default: 0. It is a percentage:
+                the boundary value equals max(abs(min), abs(max)) * tolerance, where min
+                and max are taken over the element-wise result of subtracting the previous
+                step tensor from the current step tensor. If the absolute value of a
+                difference is less than or equal to the boundary value, it is set to zero.
+
+        Raises:
+            DebuggerParamValueError: If the tensor value of the current or previous step is missing.
+
+        Returns:
+            dict, the retrieved data.
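+
+        Examples:
+            A hedged usage sketch; the tensor name is illustrative and the values of
+            both steps must already be in the cache:
+
+            >>> handler = TensorHandler()
+            >>> reply = handler.get_tensors_diff('Default/conv1.weight:0', (None, None), tolerance=0.1)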
+ """ + curr_tensor = self.get_tensor_value_by_name(tensor_name) + prev_tensor = self.get_tensor_value_by_name(tensor_name, prev=True) + if not (curr_tensor and prev_tensor): + log.error("Get current step and previous step for this tensor name %s failed.", tensor_name) + raise DebuggerParamValueError(f"Get current step and previous step for this tensor name " + f"{tensor_name} failed.") + curr_tensor_slice = curr_tensor.get_tensor_value_by_shape(shape) + prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape) + tensor_info = curr_tensor.get_basic_info() + if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray): + diff_tensor = TensorUtils.calc_diff_between_two_tensor(curr_tensor_slice, prev_tensor_slice, tolerance) + result = np.stack([prev_tensor_slice, curr_tensor_slice, diff_tensor], axis=-1) + tensor_info['diff'] = result.tolist() + stats = TensorUtils.get_statistics_from_tensor(diff_tensor) + tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats) + del tensor_info['has_prev_step'] + del tensor_info['value'] + reply = {'tensor_value': tensor_info} + return reply diff --git a/mindinsight/debugger/stream_handler/watchpoint_handler.py b/mindinsight/debugger/stream_handler/watchpoint_handler.py new file mode 100644 index 00000000..8c4c0636 --- /dev/null +++ b/mindinsight/debugger/stream_handler/watchpoint_handler.py @@ -0,0 +1,333 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define the watchpoint stream handler.""" +import numpy as np + +from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \ + DebuggerParamTypeError +from mindinsight.debugger.common.log import logger as log +from mindinsight.debugger.proto.debug_grpc_pb2 import SetCMD +from mindinsight.debugger.stream_cache.watchpoint import Watchpoint, WatchpointHit, \ + WATCHPOINT_CONDITION_MAPPING +from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase + + +class WatchpointHandler(StreamHandlerBase): + """watchpoint Handler.""" + + def __init__(self): + self._watchpoints = {} + self._deleted_watchpoints = [] + self._updated_watchpoints = {} + self._latest_id = 0 + + def put(self, value): + """ + Put Watchpoint into watchpoint handler. + + Args: + value (Watchpoint): The name of nodes that have been chosen. 
+ """ + new_id = value.watchpoint_id + self._watchpoints[new_id] = value + self._updated_watchpoints[new_id] = value + self._latest_id = new_id + log.debug("Put watchpoint %d into cache.", new_id) + + def sync_set_cmd(self): + """Clean temp watchpoints.""" + self._deleted_watchpoints = [] + self._updated_watchpoints = {} + + def get_watchpoint_by_id(self, watchpoint_id): + """Get watchpoint by watchpoint id.""" + watchpoint = self._watchpoints.get(watchpoint_id) + if not watchpoint: + log.error("Invalid watchpoint id %d", watchpoint_id) + raise DebuggerParamValueError("Invalid watchpoint id {}".format(watchpoint_id)) + + return watchpoint + + def get(self, filter_condition=False): + """ + Get the watchpoints. + + Args: + filter_condition (bool): If True, get all watchpoints without nodes. If False, + get updated watchpoints in SetCMD proto format. Default: False. + + Returns: + dict, the watchpoints. + """ + reply = [] + if not filter_condition: + # get watch condition list + for _, watchpoint in self._watchpoints.items(): + watchpoint_info = watchpoint.get_watch_condition_info() + reply.append(watchpoint_info) + else: + # get updated watchpoint list + for _, watchpoint in self._updated_watchpoints.items(): + set_cmd = watchpoint.get_set_cmd() + reply.append(set_cmd) + reply.extend(self._deleted_watchpoints) + + log.debug("get the watch points with filter_condition:%s", filter_condition) + + return {'watch_points': reply} + + def set_watch_nodes(self, graph, watch_point_id): + """ + set watch nodes for graph. + + Args: + graph (dict): The graph with list of nodes. + watch_point_id (int): The id of watchpoint. + """ + if not (watch_point_id and graph): + return + self._validate_watchpoint_id(watch_point_id) + log.debug("add watch flags") + watchpoint = self._watchpoints.get(watch_point_id) + self._set_watch_status_recursively(graph, watchpoint) + + def _set_watch_status_recursively(self, graph, watchpoint): + """Set watch status to graph.""" + if not isinstance(graph, dict): + log.warning("The graph is not dict.") + return + if graph.get('children'): + self._set_watch_status_recursively(graph.get('children'), watchpoint) + + for node in graph.get('nodes', []): + if not isinstance(node, dict): + log.warning("The node is not dict.") + return + node_name = node.get('name') + if not node_name: + continue + flag = watchpoint.get_node_status(node_name, node.get('type'), node.get('full_name')) + node['watched'] = flag + if node.get('nodes'): + self._set_watch_status_recursively(node, watchpoint) + + def create_watchpoint(self, watch_condition, watch_nodes=None, watch_point_id=None): + """ + Create watchpoint. + Args: + watch_condition (dict): The watch condition. + + - condition (str): Accept `INF` or `NAN`. + + - param (list[float]): Not defined yet. + watch_nodes (list[NodeBasicInfo]): The list of node basic info. + watch_point_id (int): The id of watchpoint. + + Returns: + int, the new id of watchpoint. + """ + validate_watch_condition(watch_condition) + new_id = self._latest_id + 1 + watchpoint = Watchpoint(new_id, watch_condition) + if watch_nodes: + watchpoint.add_nodes(watch_nodes) + elif watch_point_id: + self._validate_watchpoint_id(watch_point_id) + watchpoint.copy_nodes_from(self._watchpoints.get(watch_point_id)) + self.put(watchpoint) + + return new_id + + def update_watchpoint(self, watch_point_id, watch_nodes, watched=False): + """ + Update watchpoint. + + Args: + watch_point_id (int): The id of watchpoint. + watch_nodes (list[str]): The list of node names. 
+            watched (bool): The update operation on the nodes. If False, remove the
+                nodes from the watch nodes. If True, add the nodes to the watch nodes.
+                Default: False.
+        """
+        self._validate_watchpoint_id(watch_point_id)
+        watchpoint = self._watchpoints.get(watch_point_id)
+        if watched:
+            watchpoint.add_nodes(watch_nodes)
+        else:
+            watchpoint.remove_nodes(watch_nodes)
+        self._updated_watchpoints[watch_point_id] = watchpoint
+        log.debug("Update watchpoint %d in cache.", watch_point_id)
+
+    def delete_watchpoint(self, watch_point_id):
+        """
+        Delete a watchpoint.
+
+        Args:
+            watch_point_id (int): The id of the watchpoint.
+        """
+        self._validate_watchpoint_id(watch_point_id)
+        self._watchpoints.pop(watch_point_id)
+        set_cmd = SetCMD()
+        set_cmd.id = watch_point_id
+        set_cmd.delete = True
+        self._deleted_watchpoints.append(set_cmd)
+        log.debug("Delete watchpoint %d in cache.", watch_point_id)
+
+    def _validate_watchpoint_id(self, watch_point_id):
+        """Validate watchpoint id."""
+        if watch_point_id and watch_point_id not in self._watchpoints:
+            log.error("Invalid watchpoint id: %d.", watch_point_id)
+            raise DebuggerParamValueError("Invalid watchpoint id: {}".format(watch_point_id))
+
+
+class WatchpointHitHandler(StreamHandlerBase):
+    """Watchpoint hit handler."""
+
+    def __init__(self):
+        self._hits = {}
+
+    def put(self, value):
+        """
+        Put value into watchpoint hit cache. Called by grpc server.
+
+        Args:
+            value (dict): The watchpoint hit info.
+
+                - tensor_proto (TensorProto): The message about the hit tensor.
+
+                - watchpoint (Watchpoint): The watchpoint that the node hit.
+
+                - node_name (str): The name of the node that hit the watchpoint.
+        """
+        watchpoint_hit = WatchpointHit(
+            tensor_proto=value.get('tensor_proto'),
+            watchpoint=value.get('watchpoint'),
+            node_name=value.get('node_name')
+        )
+
+        node_name = value.get('node_name')
+        hit_tensors = self._hits.get(node_name)
+        if hit_tensors is None:
+            hit_tensors = []
+            self._hits[node_name] = hit_tensors
+        if watchpoint_hit not in hit_tensors:
+            hit_tensors.append(watchpoint_hit)
+
+    def get(self, filter_condition=None):
+        """
+        Get the watchpoint hit list.
+
+        Args:
+            filter_condition (str): Get the watchpoint hits according to the specified
+                node name. If not given, get all watchpoint hits. Default: None.
+
+        Returns:
+            dict, the watchpoint hit list.
+        """
+        if filter_condition is None:
+            log.debug("Get all watchpoint hits.")
+            reply = self.get_watchpoint_hits()
+        else:
+            log.debug("Get the watchpoint hits for node: <%s>.", filter_condition)
+            reply = self._hits.get(filter_condition)
+
+        return reply
+
+    def get_watchpoint_hits(self):
+        """Return the list of watchpoint hits."""
+        watch_point_hits = []
+        for node_name, watchpoint_hits in self._hits.items():
+            watch_points = [watchpoint_hit.watchpoint for watchpoint_hit in watchpoint_hits]
+            watch_point_hits.append({
+                'node_name': node_name,
+                'watch_points': watch_points
+            })
+
+        return {'watch_point_hits': watch_point_hits}
+
+    def _is_tensor_hit(self, tensor_name):
+        """Check if the tensor is recorded in the hit cache."""
+        node_name = tensor_name.split(':')[0]
+        watchpoint_hits = self.get(node_name)
+        if watchpoint_hits is None:
+            return False
+
+        for watchpoint_hit in watchpoint_hits:
+            if tensor_name == watchpoint_hit.tensor_name:
+                return True
+
+        return False
+
+    def update_tensor_history(self, tensor_history):
+        """
+        Add hit flag to tensor history.
+
+        Args:
+            tensor_history (dict): The tensor history.
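+
+        Note:
+            Each entry in `tensor_history['tensor_history']` is updated in place
+            with a boolean `is_hit` field.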
+ """ + if not self._hits: + return + + # add hit tensor names to `tensor_names` + for tensor_info in tensor_history.get('tensor_history'): + tensor_name = tensor_info['full_name'] + hit_flag = self._is_tensor_hit(tensor_name) + tensor_info['is_hit'] = hit_flag + + +def validate_watch_condition(watch_condition): + """Validate watch condition.""" + if not isinstance(watch_condition, dict): + log.error(" should be dict. %s received.", watch_condition) + raise DebuggerParamTypeError(" should be dict.") + # validate condition + condition = watch_condition.get('condition') + if condition not in WATCHPOINT_CONDITION_MAPPING.keys(): + log.error("Invalid watch condition. Acceptable values are <%s>.", + str(WATCHPOINT_CONDITION_MAPPING.keys())) + raise DebuggerParamValueError("Invalid watch condition value.") + # validate param + validate_watch_condition_params(watch_condition) + + +def validate_watch_condition_params(watch_condition): + """ + Validate watch condition parameters. + + Args: + watch_condition (dict): Watch condition. + + - condition (str): Condition type. Should be in WATCHPOINT_CONDITION_MAPPING. + + - param (list): Condition value. Should be given for comparison condition. The value will + be translated to np.float32. + """ + condition = watch_condition.get('condition') + param = watch_condition.get('param') + if condition in ['NAN', 'INF', 'OVERFLOW']: + if param: + log.error("No param is expected for %s condition.", condition) + raise DebuggerParamValueError("No param is expected.") + else: + if not isinstance(param, (float, int)): + log.error("Number param should be given for condition <%s>.", + condition) + raise DebuggerParamValueError("Number param should be given.") + if np.isinf(np.float32(param)): + log.error("Condition param should be float32.") + raise DebuggerParamValueError("The value of condition param should be within float32.") diff --git a/mindinsight/scripts/start.py b/mindinsight/scripts/start.py index 2a41eb1c..1e0d625d 100644 --- a/mindinsight/scripts/start.py +++ b/mindinsight/scripts/start.py @@ -14,19 +14,28 @@ # ============================================================================ """Start mindinsight service.""" +import argparse import os -import sys import re -import argparse +import sys from importlib import import_module import psutil from mindinsight.conf import settings from mindinsight.utils.command import BaseCommand +from mindinsight.utils.exceptions import PortNotAvailableError from mindinsight.utils.hook import HookUtils from mindinsight.utils.hook import init -from mindinsight.utils.exceptions import PortNotAvailableError + + +def str2bool(string): + """Convert str to bool""" + if string.lower() == 'false': + return False + if string.lower() == 'true': + return True + raise ValueError class ConfigAction(argparse.Action): @@ -146,6 +155,23 @@ class UrlPathPrefixAction(argparse.Action): setattr(namespace, self.dest, prefix) +class EnableDebuggerAction(argparse.Action): + """SSL certificate action class definition.""" + + def __call__(self, parser, namespace, values, option_string=None): + """ + Inherited __call__ method from argparse.Action. + + Args: + parser (ArgumentParser): Passed-in argument parser. + namespace (Namespace): Namespace object to hold arguments. + values (object): Argument values with type depending on argument definition. + option_string (str): Optional string for specific argument name. Default: None. 
+ """ + enable_debugger = values + setattr(namespace, self.dest, enable_debugger) + + class Command(BaseCommand): """ Start mindinsight service. @@ -186,6 +212,14 @@ class Command(BaseCommand): Custom port ranging from %s to %s. Default value is %s. """ % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.PORT)) + parser.add_argument( + '--debugger_port', + type=int, + action=PortAction, + help=""" + Debugger port ranging from %s to %s. Default value is %s. + """ % (PortAction.MIN_PORT, PortAction.MAX_PORT, settings.DEBUGGER_PORT)) + parser.add_argument( '--url-path-prefix', type=str, @@ -197,6 +231,14 @@ class Command(BaseCommand): dot or double dots. Default value is ''. """) + parser.add_argument( + '--enable_debugger', + type=str2bool, + action=EnableDebuggerAction, + default=False, + help=""" + Enable debugger or not. + Dfault is False.""") for hook in HookUtils.instance().hooks(): hook.register_startup_arguments(parser) diff --git a/mindinsight/utils/constant.py b/mindinsight/utils/constant.py index 0fc704f0..49b34acd 100644 --- a/mindinsight/utils/constant.py +++ b/mindinsight/utils/constant.py @@ -33,6 +33,7 @@ class MindInsightModules(Enum): SCRIPTCONVERTER = 7 WIZARD = 9 OPTIMIZER = 10 + DEBUGGER = 11 class GeneralErrors(Enum): @@ -56,6 +57,10 @@ class LineageMgrErrors(Enum): """Enum definition for lineage errors.""" +class DebuggerErrors(Enum): + """Enum definition for debugger errors.""" + + class DataVisualErrors(Enum): """Enum definition for datavisual errors.""" RESTFUL_API_NOT_EXIST = 1 diff --git a/mindinsight/utils/tensor.py b/mindinsight/utils/tensor.py new file mode 100644 index 00000000..5a28bf69 --- /dev/null +++ b/mindinsight/utils/tensor.py @@ -0,0 +1,298 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tensor utils.""" + +import numpy as np + +from mindinsight.datavisual.utils.tools import to_int +from mindinsight.utils.exceptions import ParamValueError +from mindinsight.utils.exceptions import ParamTypeError +from mindinsight.utils.log import utils_logger as logger + +F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max + + +class Statistics: + """Statistics data class. + + Args: + max_value (float): max value of tensor data. + min_value (float): min value of tensor data. + avg_value (float): avg value of tensor data. + count (int): total count of tensor data. + nan_count (int): count of NAN. + neg_inf_count (int): count of negative INF. + pos_inf_count (int): count of positive INF. 
+ """ + + def __init__(self, max_value=0, min_value=0, avg_value=0, + count=0, nan_count=0, neg_inf_count=0, pos_inf_count=0): + self._max = max_value + self._min = min_value + self._avg = avg_value + self._count = count + self._nan_count = nan_count + self._neg_inf_count = neg_inf_count + self._pos_inf_count = pos_inf_count + + @property + def max(self): + """Get max value of tensor.""" + return self._max + + @property + def min(self): + """Get min value of tensor.""" + return self._min + + @property + def avg(self): + """Get avg value of tensor.""" + return self._avg + + @property + def count(self): + """Get total count of tensor.""" + return self._count + + @property + def nan_count(self): + """Get count of NAN.""" + return self._nan_count + + @property + def neg_inf_count(self): + """Get count of negative INF.""" + return self._neg_inf_count + + @property + def pos_inf_count(self): + """Get count of positive INF.""" + return self._pos_inf_count + + +class TensorUtils: + """Tensor Utils class.""" + + @staticmethod + def validate_dims_format(dims): + """ + Validate correct of format of dimension parameter. + + Args: + dims (str): Dims of tensor. Its format is something like this "[0, 0, :, :]". + + Raises: + ParamValueError: If format of dims is not correct. + """ + if dims is not None: + if not isinstance(dims, str): + raise ParamTypeError(dims, str) + dims = dims.strip() + if not (dims.startswith('[') and dims.endswith(']')): + raise ParamValueError('The value: {} of dims must be ' + 'start with `[` and end with `]`.'.format(dims)) + for dim in dims[1:-1].split(','): + dim = dim.strip() + if dim == ":": + continue + if dim.startswith('-'): + dim = dim[1:] + if not dim.isdigit(): + raise ParamValueError('The value: {} of dims in the square brackets ' + 'must be int or `:`.'.format(dims)) + + @staticmethod + def convert_array_from_str_dims(dims, limit=0): + """ + Convert string of dims data to array. + + Args: + dims (str): Specify dims of tensor. + limit (int): The max flexible dimension count, default value is 0 which means that there is no limitation. + + Returns: + list, a string like this: "[0, 0, :, :]" will convert to this value: [0, 0, None, None]. + + Raises: + ParamValueError, If flexible dimensions exceed limit value. + """ + dims = dims.strip().lstrip('[').rstrip(']') + dims_list = [] + count = 0 + for dim in dims.split(','): + dim = dim.strip() + if dim == ':': + dims_list.append(None) + count += 1 + else: + dims_list.append(to_int(dim, "dim")) + if limit and count > limit: + raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}" + .format(limit, count)) + return dims_list + + @staticmethod + def get_specific_dims_data(ndarray, dims, tensor_dims): + """ + Get specific dims data. + + Args: + ndarray (numpy.ndarray): An ndarray of numpy. + dims (list): A list of specific dims. + tensor_dims (list): A list of tensor dims. + + Returns: + numpy.ndarray, an ndarray of specific dims tensor data. + + Raises: + ParamValueError, If the length of param dims is not equal to the length of tensor dims or + the index of param dims out of range. 
+ """ + if len(dims) != len(tensor_dims): + raise ParamValueError("The length of param dims: {}, is not equal to the " + "length of tensor dims: {}.".format(len(dims), len(tensor_dims))) + indices = [] + for k, d in enumerate(dims): + if d is not None: + if d >= tensor_dims[k]: + raise ParamValueError("The index: {} of param dims out of range: {}.".format(d, tensor_dims[k])) + indices.append(d) + else: + indices.append(slice(0, tensor_dims[k])) + result = ndarray[tuple(indices)] + # Make sure the return type is numpy.ndarray. + if not isinstance(result, np.ndarray): + result = np.array(result) + return result + + @staticmethod + def get_statistics_from_tensor(tensors): + """ + Calculates statistics data of tensor. + + Args: + tensors (numpy.ndarray): An numpy.ndarray of tensor data. + + Returns: + an instance of Statistics. + """ + ma_value = np.ma.masked_invalid(tensors) + total, valid = tensors.size, ma_value.count() + invalids = [] + for isfn in np.isnan, np.isposinf, np.isneginf: + if total - valid > sum(invalids): + count = np.count_nonzero(isfn(tensors)) + invalids.append(count) + else: + invalids.append(0) + + nan_count, pos_inf_count, neg_inf_count = invalids + if not valid: + logger.warning('There are no valid values in the tensors(size=%d, shape=%s)', total, tensors.shape) + statistics = Statistics(max_value=0, + min_value=0, + avg_value=0, + count=total, + nan_count=nan_count, + neg_inf_count=neg_inf_count, + pos_inf_count=pos_inf_count) + return statistics + + # BUG: max of a masked array with dtype np.float16 returns inf + # See numpy issue#15077 + if issubclass(tensors.dtype.type, np.floating): + tensor_min = ma_value.min(fill_value=np.PINF) + tensor_max = ma_value.max(fill_value=np.NINF) + if tensor_min < F32_MIN or tensor_max > F32_MAX: + logger.warning('Values(%f, %f) are too large, you may encounter some undefined ' + 'behaviours hereafter.', tensor_min, tensor_max) + else: + tensor_min = ma_value.min() + tensor_max = ma_value.max() + tensor_sum = ma_value.sum(dtype=np.float64) + statistics = Statistics(max_value=tensor_max, + min_value=tensor_min, + avg_value=tensor_sum / valid, + count=total, + nan_count=nan_count, + neg_inf_count=neg_inf_count, + pos_inf_count=pos_inf_count) + return statistics + + @staticmethod + def get_statistics_dict(stats): + """ + Get statistics dict according to statistics value. + + Args: + stats (Statistics): An instance of Statistics. + + Returns: + dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'. + """ + statistics = { + "max": float(stats.max), + "min": float(stats.min), + "avg": float(stats.avg), + "count": stats.count, + "nan_count": stats.nan_count, + "neg_inf_count": stats.neg_inf_count, + "pos_inf_count": stats.pos_inf_count + } + return statistics + + @staticmethod + def calc_diff_between_two_tensor(first_tensor, second_tensor, tolerance): + """ + Calculate the difference between the first tensor and the second tensor. + + Args: + first_tensor (numpy.ndarray): Specify the first tensor. + second_tensor (numpy.ndarray): Specify the second tensor. + tolerance (float): The tolerance of difference between the first tensor and the second tensor. + Its is a percentage. The boundary value is equal to max(abs(min),abs(max)) * tolerance. + The function of min and max is being used to calculate the min value and max value of + the result of the first tensor subtract the second tensor. If the absolute value of + result is less than or equal to boundary value, the result will set to be zero. 
+
+        Returns:
+            numpy.ndarray, the value of the first tensor minus the second tensor,
+                with any element whose absolute value is within the tolerance
+                boundary set to zero.
+
+        Raises:
+            ParamTypeError: If either of the two tensors is not a numpy.ndarray.
+            ParamValueError: If the shape or dtype of the two tensors is not the same.
+        """
+        if not isinstance(first_tensor, np.ndarray):
+            raise ParamTypeError('first_tensor', np.ndarray)
+
+        if not isinstance(second_tensor, np.ndarray):
+            raise ParamTypeError('second_tensor', np.ndarray)
+
+        if first_tensor.shape != second_tensor.shape:
+            raise ParamValueError("The shape: {} of the first tensor is not equal to the shape: {} "
+                                  "of the second tensor.".format(first_tensor.shape, second_tensor.shape))
+
+        if first_tensor.dtype != second_tensor.dtype:
+            raise ParamValueError("The dtype: {} of the first tensor is not equal to the dtype: {} "
+                                  "of the second tensor.".format(first_tensor.dtype, second_tensor.dtype))
+
+        diff_tensor = np.subtract(first_tensor, second_tensor)
+        stats = TensorUtils.get_statistics_from_tensor(diff_tensor)
+        boundary_value = max(abs(stats.max), abs(stats.min)) * tolerance
+        is_close = np.isclose(first_tensor, second_tensor, atol=boundary_value, rtol=0)
+        result = np.multiply(diff_tensor, ~is_close)
+        return result
diff --git a/requirements.txt b/requirements.txt
index 0ce0f827..bd08d2ba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,4 +16,5 @@ Werkzeug>=1.0.0
 tabulate>=0.8.6
 pandas>=1.0.4
 yapf>=0.30.0
-treelib>=1.6.1
\ No newline at end of file
+treelib>=1.6.1
+grpcio>=1.29.0
diff --git a/tests/ut/datavisual/data_transform/test_tensor_container.py b/tests/ut/datavisual/data_transform/test_tensor_container.py
index edb8208f..7f44d0a9 100644
--- a/tests/ut/datavisual/data_transform/test_tensor_container.py
+++ b/tests/ut/datavisual/data_transform/test_tensor_container.py
@@ -14,9 +14,11 @@
 # ============================================================================
 """Test tensor container."""
 import unittest.mock as mock
+
 import numpy as np
 
 from mindinsight.datavisual.data_transform import tensor_container as tensor
+from mindinsight.utils.tensor import TensorUtils
 
 
 class TestTensorContainer:
@@ -34,8 +36,9 @@ class TestTensorContainer:
 
     def test_get_statistics_from_tensor(self):
         """Tests get statistics from tensor."""
-        ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape([2, 2, 2])
-        statistics = tensor.get_statistics_from_tensor(ndarray)
+        ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape(
+            [2, 2, 2])
+        statistics = TensorUtils.get_statistics_from_tensor(ndarray)
         assert (statistics.max, statistics.min, statistics.avg, statistics.count,
                 statistics.nan_count, statistics.neg_inf_count, statistics.pos_inf_count) == \
                (5, 1, 3, 8,
@@ -43,8 +46,9 @@
 
     def test_calc_original_buckets(self):
         """Tests calculate original buckets."""
-        ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape([2, 2, 2])
-        statistics = tensor.get_statistics_from_tensor(ndarray)
+        ndarray = np.array([1, 2, 3, 4, 5, float('-INF'), float('INF'), float('NAN')]).reshape(
+            [2, 2, 2])
+        statistics = TensorUtils.get_statistics_from_tensor(ndarray)
         buckets = tensor.calc_original_buckets(ndarray, statistics)
 
         assert (buckets[0].left, buckets[0].width, buckets[0].count) == (1, 2, 2)
diff --git a/tests/ut/datavisual/processors/test_tensor_processor.py b/tests/ut/datavisual/processors/test_tensor_processor.py
index a3cdcef1..d5be3371 100644
--- a/tests/ut/datavisual/processors/test_tensor_processor.py
+++ b/tests/ut/datavisual/processors/test_tensor_processor.py
@@ -29,10 +29,8 @@ from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
 from mindinsight.datavisual.common.exceptions import TensorNotExistError
 from mindinsight.datavisual.data_transform import data_manager
 from mindinsight.datavisual.data_transform.tensor_container import calc_original_buckets
-from mindinsight.datavisual.data_transform.tensor_container import get_statistics_from_tensor
 from mindinsight.datavisual.processors.tensor_processor import TensorProcessor
-from mindinsight.datavisual.processors.tensor_processor import get_specific_dims_data
-from mindinsight.datavisual.processors.tensor_processor import get_statistics_dict
+from mindinsight.utils.tensor import TensorUtils
 from mindinsight.datavisual.utils import crc32
 from mindinsight.utils.exceptions import ParamValueError
 from mindinsight.utils.exceptions import ParamMissError
@@ -187,7 +185,7 @@ class TestTensorProcessor:
         dims = expected_values.get('value').get("dims")
         expected_data = np.array(expected_values.get('value').get("float_data")).reshape(dims)
         recv_tensor = np.array(recv_values.get('value').get("data"))
-        expected_tensor = get_specific_dims_data(expected_data, [0, 0, None, None], dims)
+        expected_tensor = TensorUtils.get_specific_dims_data(expected_data, [0, 0, None, None], dims)
         assert np.sum(np.isclose(recv_tensor, expected_tensor, rtol=1e-6) == 0) == 0
 
     @pytest.mark.usefixtures('load_tensor_record')
@@ -204,7 +202,8 @@ class TestTensorProcessor:
         assert recv_values.get('wall_time') == expected_values.get('wall_time')
         assert recv_values.get('step') == expected_values.get('step')
         expected_data = expected_values.get('value').get("float_data")
-        expected_statistic = get_statistics_dict(get_statistics_from_tensor(expected_data))
+        expected_statistic_instance = TensorUtils.get_statistics_from_tensor(expected_data)
+        expected_statistic = TensorUtils.get_statistics_dict(expected_statistic_instance)
         recv_statistic = recv_values.get('value').get("statistics")
         assert recv_statistic.get("max") - expected_statistic.get("max") < 1e-6
         assert recv_statistic.get("min") - expected_statistic.get("min") < 1e-6
@@ -225,7 +224,7 @@ class TestTensorProcessor:
         assert recv_values.get('wall_time') == expected_values.get('wall_time')
         assert recv_values.get('step') == expected_values.get('step')
         expected_data = expected_values.get('value').get("float_data")
-        expected_statistic = get_statistics_from_tensor(expected_data)
+        expected_statistic = TensorUtils.get_statistics_from_tensor(expected_data)
         expected_buckets = calc_original_buckets(expected_data, expected_statistic)
         recv_buckets = recv_values.get('value').get("histogram_buckets")