Browse Source

optimize tensor value connection method

tags/v1.1.0
yelihua 5 years ago
parent
commit
5368f2bfc9
3 changed files with 31 additions and 47 deletions
  1. +5
    -7
      mindinsight/debugger/debugger_grpc_server.py
  2. +16
    -10
      mindinsight/debugger/stream_cache/tensor.py
  3. +10
    -30
      mindinsight/debugger/stream_handler/tensor_handler.py

+ 5
- 7
mindinsight/debugger/debugger_grpc_server.py View File

@@ -14,7 +14,6 @@
# ============================================================================
"""Implement the debugger grpc server."""
import copy

from functools import wraps

import mindinsight
@@ -452,21 +451,20 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
def SendTensors(self, request_iterator, context):
    """
    Send tensors into DebuggerCache.

    Streams tensor chunks from the client; each chunk's raw bytes are
    accumulated until a chunk flagged `finished` arrives, at which point the
    collected byte pieces are handed to the tensor stream in one `put` call.

    Args:
        request_iterator (Iterator[TensorProto]): Stream of tensor chunks.
        context (grpc.ServicerContext): The gRPC request context.

    Returns:
        EventReply, the acknowledge reply.
    """
    log.info("Received tensor.")
    tensor_contents = []
    tensor_stream = self._cache_store.get_stream_handler(Streams.TENSOR)
    metadata_stream = self._cache_store.get_stream_handler(Streams.METADATA)
    # NOTE(review): tensor_names is collected but not consumed in this hunk —
    # presumably used by surrounding code or left over; confirm before removing.
    tensor_names = []
    step = metadata_stream.step
    for tensor in request_iterator:
        # Accumulate raw content pieces; the final chunk carries `finished`.
        tensor_contents.append(tensor.tensor_content)
        if tensor.finished:
            update_flag = tensor_stream.put(
                {'step': step, 'tensor_proto': tensor, 'tensor_contents': tensor_contents})
            if self._received_view_cmd.get('wait_for_tensor') and update_flag:
                # update_flag is used to avoid querying empty tensors again
                self._received_view_cmd['wait_for_tensor'] = False
                log.debug("Set wait for tensor flag to False.")
            tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
            # Reset the accumulator for the next tensor in the stream.
            tensor_contents = []
            continue
    reply = get_ack_reply()
    return reply


+ 16
- 10
mindinsight/debugger/stream_cache/tensor.py View File

@@ -17,11 +17,11 @@ from abc import abstractmethod, ABC

import numpy as np

from mindinsight.utils.tensor import TensorUtils
from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import LOGGER as log
from mindinsight.debugger.common.utils import NUMPY_TYPE_MAP
from mindinsight.debugger.proto.ms_graph_pb2 import DataType
from mindinsight.utils.tensor import TensorUtils


class BaseTensor(ABC):
@@ -114,14 +114,21 @@ class BaseTensor(ABC):


class OpTensor(BaseTensor):
"""Tensor data structure for operator Node."""
"""
Tensor data structure for operator Node.

Args:
tensor_proto (TensorProto): Tensor proto contains tensor basic info.
tensor_content (byte): Tensor content value in byte format.
step (int): The step of the tensor.
"""
max_number_data_show_on_ui = 100000

def __init__(self, tensor_proto, tensor_content, step=0):
    """
    Initialize an OpTensor.

    Args:
        tensor_proto (TensorProto): Tensor proto contains tensor basic info.
        tensor_content (bytes): Tensor content value in byte format.
        step (int): The step of the tensor. Default: 0.
    """
    # the type of tensor_proto is TensorProto
    super(OpTensor, self).__init__(step)
    self._tensor_proto = tensor_proto
    # Decode the raw bytes into a numpy array up front.
    self._value = self.to_numpy(tensor_content)
    self._stats = None
    self._tensor_comparison = None

@@ -169,21 +176,20 @@ class OpTensor(BaseTensor):
"""The property of tensor_comparison."""
return self._tensor_comparison

def to_numpy(self, tensor_content):
    """
    Construct tensor value from bytes to numpy.

    Args:
        tensor_content (bytes): The raw tensor content.

    Returns:
        Union[None, np.ndarray], the value of the tensor, or None when
        the content is empty.
    """
    tensor_value = None
    if tensor_content:
        # Map the proto dtype name to the matching numpy dtype before decoding.
        np_type = NUMPY_TYPE_MAP.get(self.dtype)
        tensor_value = np.frombuffer(tensor_content, dtype=np_type)
        tensor_value = tensor_value.reshape(self.shape)
    return tensor_value



+ 10
- 30
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -59,46 +59,24 @@ class TensorHandler(StreamHandlerBase):
def put(self, value):
    """
    Put a received tensor into the cache.

    Args:
        value (dict): The Tensor proto message.

            - step (int): The current step of tensor.
            - tensor_proto (TensorProto): The tensor proto.
            - tensor_contents (list[bytes]): The list of tensor content values.

    Returns:
        bool, the tensor has updated successfully.
    """
    tensor_proto = value.get('tensor_proto')
    # Content travels separately in 'tensor_contents'; drop any copy embedded
    # in the proto so it is not stored twice.
    tensor_proto.ClearField('tensor_content')
    step = value.get('step', 0)
    if tensor_proto.iter and step > 0:
        # A tensor flagged with `iter` belongs to the previous step.
        log.debug("Received previous tensor.")
        step -= 1
    tensor_content = b''.join(value.get('tensor_contents'))
    tensor = OpTensor(tensor_proto, tensor_content, step)
    flag = self._put_tensor_into_cache(tensor, step)
    log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, step, flag)
    return flag

@staticmethod
def _get_merged_tensor(tensor_protos):
    """
    Merge a list of parsed tensor protos into a single proto.

    The last proto in the list is used as the carrier; when several protos
    are present, their content bytes are concatenated into it. If any proto
    is missing its content, merging stops at that point with a warning and
    the partial content gathered so far is kept.

    Args:
        tensor_protos (list[TensorProto]): List of tensor proto.

    Returns:
        TensorProto, merged tensor proto.
    """
    merged_tensor = tensor_protos[-1]
    if len(tensor_protos) > 1:
        pieces = []
        for proto in tensor_protos:
            if not proto.tensor_content:
                log.warning("Doesn't find tensor value for %s:%s",
                            proto.node_name, proto.slot)
                break
            pieces.append(proto.tensor_content)
        merged_tensor.tensor_content = b''.join(pieces)
        log.debug("Merge multi tensor values into one.")
    return merged_tensor

def _put_tensor_into_cache(self, tensor, step):
"""
Put tensor into cache.
@@ -146,9 +124,11 @@ class TensorHandler(StreamHandlerBase):
continue
if DataType.Name(const_val.value.dtype) == "DT_TENSOR":
tensor_proto = const_val.value.tensor_val
tensor_value = tensor_proto.tensor_content
tensor_proto.ClearField('tensor_content')
tensor_proto.node_name = const_val.key
tensor_proto.slot = '0'
const_tensor = OpTensor(tensor_proto)
const_tensor = OpTensor(tensor_proto, tensor_value)
else:
const_tensor = ConstTensor(const_val)
self._const_vals[const_tensor.name] = const_tensor


Loading…
Cancel
Save