| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| """Implement the debugger grpc server.""" | |||
| import copy | |||
| from functools import wraps | |||
| import mindinsight | |||
| @@ -452,21 +451,20 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def SendTensors(self, request_iterator, context): | |||
| """Send tensors into DebuggerCache.""" | |||
| log.info("Received tensor.") | |||
| tensor_construct = [] | |||
| tensor_contents = [] | |||
| tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR) | |||
| metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA) | |||
| tensor_names = [] | |||
| step = metadata_stream.step | |||
| for tensor in request_iterator: | |||
| tensor_construct.append(tensor) | |||
| tensor_contents.append(tensor.tensor_content) | |||
| if tensor.finished: | |||
| update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct}) | |||
| update_flag = tensor_stream.put( | |||
| {'step': step, 'tensor_proto': tensor, 'tensor_contents': tensor_contents}) | |||
| if self._received_view_cmd.get('wait_for_tensor') and update_flag: | |||
| # update_flag is used to avoid querying empty tensors again | |||
| self._received_view_cmd['wait_for_tensor'] = False | |||
| log.debug("Set wait for tensor flag to False.") | |||
| tensor_construct = [] | |||
| tensor_names.append(':'.join([tensor.node_name, tensor.slot])) | |||
| tensor_contents = [] | |||
| continue | |||
| reply = get_ack_reply() | |||
| return reply | |||
| @@ -17,11 +17,11 @@ from abc import abstractmethod, ABC | |||
| import numpy as np | |||
| from mindinsight.utils.tensor import TensorUtils | |||
| from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError | |||
| from mindinsight.debugger.common.log import LOGGER as log | |||
| from mindinsight.debugger.common.utils import NUMPY_TYPE_MAP | |||
| from mindinsight.debugger.proto.ms_graph_pb2 import DataType | |||
| from mindinsight.utils.tensor import TensorUtils | |||
| class BaseTensor(ABC): | |||
| @@ -114,14 +114,21 @@ class BaseTensor(ABC): | |||
| class OpTensor(BaseTensor): | |||
| """Tensor data structure for operator Node.""" | |||
| """ | |||
| Tensor data structure for operator Node. | |||
| Args: | |||
| tensor_proto (TensorProto): Tensor proto contains tensor basic info. | |||
| tensor_content (byte): Tensor content value in byte format. | |||
| step (int): The step of the tensor. | |||
| """ | |||
| max_number_data_show_on_ui = 100000 | |||
| def __init__(self, tensor_proto, step=0): | |||
| def __init__(self, tensor_proto, tensor_content, step=0): | |||
| # the type of tensor_proto is TensorProto | |||
| super(OpTensor, self).__init__(step) | |||
| self._tensor_proto = tensor_proto | |||
| self._value = self.generate_value_from_proto(tensor_proto) | |||
| self._value = self.to_numpy(tensor_content) | |||
| self._stats = None | |||
| self._tensor_comparison = None | |||
| @@ -169,21 +176,20 @@ class OpTensor(BaseTensor): | |||
| """The property of tensor_comparison.""" | |||
| return self._tensor_comparison | |||
| def generate_value_from_proto(self, tensor_proto): | |||
| def to_numpy(self, tensor_content): | |||
| """ | |||
| Generate tensor value from proto. | |||
| Construct tensor content from byte to numpy. | |||
| Args: | |||
| tensor_proto (TensorProto): The tensor proto. | |||
| tensor_content (byte): The tensor content. | |||
| Returns: | |||
| Union[None, np.ndarray], the value of the tensor. | |||
| """ | |||
| tensor_value = None | |||
| if tensor_proto.tensor_content: | |||
| tensor_value = tensor_proto.tensor_content | |||
| if tensor_content: | |||
| np_type = NUMPY_TYPE_MAP.get(self.dtype) | |||
| tensor_value = np.frombuffer(tensor_value, dtype=np_type) | |||
| tensor_value = np.frombuffer(tensor_content, dtype=np_type) | |||
| tensor_value = tensor_value.reshape(self.shape) | |||
| return tensor_value | |||
| @@ -59,46 +59,24 @@ class TensorHandler(StreamHandlerBase): | |||
| value (dict): The Tensor proto message. | |||
| - step (int): The current step of tensor. | |||
| - tensor_protos (list[TensorProto]): The tensor proto. | |||
| - tensor_proto (TensorProto): The tensor proto. | |||
| - tensor_contents (list[byte]): The list of tensor content values. | |||
| Returns: | |||
| bool, the tensor has updated successfully. | |||
| """ | |||
| tensor_protos = value.get('tensor_protos') | |||
| merged_tensor = self._get_merged_tensor(tensor_protos) | |||
| tensor_proto = value.get('tensor_proto') | |||
| tensor_proto.ClearField('tensor_content') | |||
| step = value.get('step', 0) | |||
| if merged_tensor.iter and step > 0: | |||
| if tensor_proto.iter and step > 0: | |||
| log.debug("Received previous tensor.") | |||
| step -= 1 | |||
| tensor = OpTensor(merged_tensor, step) | |||
| tensor_content = b''.join(value.get('tensor_contents')) | |||
| tensor = OpTensor(tensor_proto, tensor_content, step) | |||
| flag = self._put_tensor_into_cache(tensor, step) | |||
| log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, step, flag) | |||
| return flag | |||
| @staticmethod | |||
| def _get_merged_tensor(tensor_protos): | |||
| """ | |||
| Merged list of parsed tensor value into one. | |||
| Args: | |||
| tensor_protos (list[TensorProto]): List of tensor proto. | |||
| Returns: | |||
| TensorProto, merged tensor proto. | |||
| """ | |||
| merged_tensor = tensor_protos[-1] | |||
| if len(tensor_protos) > 1: | |||
| tensor_value = bytes() | |||
| for tensor_proto in tensor_protos: | |||
| if not tensor_proto.tensor_content: | |||
| log.warning("Doesn't find tensor value for %s:%s", | |||
| tensor_proto.node_name, tensor_proto.slot) | |||
| break | |||
| tensor_value += tensor_proto.tensor_content | |||
| merged_tensor.tensor_content = tensor_value | |||
| log.debug("Merge multi tensor values into one.") | |||
| return merged_tensor | |||
| def _put_tensor_into_cache(self, tensor, step): | |||
| """ | |||
| Put tensor into cache. | |||
| @@ -146,9 +124,11 @@ class TensorHandler(StreamHandlerBase): | |||
| continue | |||
| if DataType.Name(const_val.value.dtype) == "DT_TENSOR": | |||
| tensor_proto = const_val.value.tensor_val | |||
| tensor_value = tensor_proto.tensor_content | |||
| tensor_proto.ClearField('tensor_content') | |||
| tensor_proto.node_name = const_val.key | |||
| tensor_proto.slot = '0' | |||
| const_tensor = OpTensor(tensor_proto) | |||
| const_tensor = OpTensor(tensor_proto, tensor_value) | |||
| else: | |||
| const_tensor = ConstTensor(const_val) | |||
| self._const_vals[const_tensor.name] = const_tensor | |||