Browse Source

Fix the bug in updating tensor history

tags/v1.0.0
yelihua 5 years ago
parent
commit
8db1aec6c3
3 changed files with 69 additions and 26 deletions
  1. +3
    -4
      mindinsight/debugger/debugger_grpc_server.py
  2. +37
    -18
      mindinsight/debugger/stream_cache/tensor.py
  3. +29
    -4
      mindinsight/debugger/stream_handler/tensor_handler.py

+ 3
- 4
mindinsight/debugger/debugger_grpc_server.py View File

@@ -171,7 +171,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
def _get_next_command(self):
"""Get next command."""
self._pos, event = self._cache_store.get_command(self._pos)
log.debug("Received event :%s", event)
if event is None:
return event
if isinstance(event, dict):
@@ -188,7 +187,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
"""Deal with view cmd."""
view_cmd = event.get('view_cmd')
node_name = event.get('node_name')
log.debug("Receive view cmd %s for node: %s.", view_cmd, node_name)
log.debug("Receive view cmd for node: %s.", node_name)
if not (view_cmd and node_name):
log.debug("Invalid view command. Ignore it.")
return None
@@ -265,10 +264,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer):
for tensor in request_iterator:
tensor_construct.append(tensor)
if tensor.finished:
if self._received_view_cmd.get('wait_for_tensor') and tensor.tensor_content:
update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
if self._received_view_cmd.get('wait_for_tensor') and update_flag:
self._received_view_cmd['wait_for_tensor'] = False
log.debug("Set wait for tensor flag to False.")
tensor_stream.put({'step': step, 'tensor_protos': tensor_construct})
tensor_construct = []
tensor_names.append(':'.join([tensor.node_name, tensor.slot]))
continue


+ 37
- 18
mindinsight/debugger/stream_cache/tensor.py View File

@@ -50,8 +50,13 @@ class BaseTensor(ABC):
def value(self):
"""The property of tensor shape."""

@property
def empty(self):
"""If the tensor value is valid."""
return self.value is None

@abstractmethod
def get_tensor_value_by_shape(self, shape=None):
def get_tensor_serializable_value_by_shape(self, shape=None):
"""Get tensor value by shape."""

def _to_dict(self):
@@ -67,8 +72,9 @@ class BaseTensor(ABC):

def get_basic_info(self):
"""Return basic info about tensor info."""
tensor_value = self.value
if not self.shape:
value = self.value
value = tensor_value.tolist() if isinstance(tensor_value, np.ndarray) else tensor_value
else:
value = 'click to view'
res = self._to_dict()
@@ -91,7 +97,7 @@ class OpTensor(BaseTensor):
# the type of tensor_proto is TensorProto
super(OpTensor, self).__init__(step)
self._tensor_proto = tensor_proto
self._value = self.generate_value(tensor_proto)
self._value = self.generate_value_from_proto(tensor_proto)

@property
def name(self):
@@ -115,19 +121,18 @@ class OpTensor(BaseTensor):
@property
def value(self):
"""The property of tensor value."""
tensor_value = None
if self._value is not None:
tensor_value = self._value.tolist()
return self._value

return tensor_value
def generate_value_from_proto(self, tensor_proto):
"""
Generate tensor value from proto.

@property
def numpy_value(self):
"""The property of tensor value in numpy type."""
return self._value
Args:
tensor_proto (TensorProto): The tensor proto.

def generate_value(self, tensor_proto):
"""Generate tensor value from proto."""
Returns:
Union[None, np.ndarray], the value of the tensor.
"""
tensor_value = None
if tensor_proto.tensor_content:
tensor_value = tensor_proto.tensor_content
@@ -166,7 +171,7 @@ class OpTensor(BaseTensor):
shape (tuple): The specified shape.

Returns:
Union[None, str, numpy.ndarray], the sub-tensor.
Union[None, str, numpy.ndarray], the value of parsed tensor.
"""
if self._value is None:
log.warning("%s has no value yet.", self.name)
@@ -199,6 +204,7 @@ class ConstTensor(BaseTensor):
# the type of const_proto is NamedValueProto
super(ConstTensor, self).__init__()
self._const_proto = const_proto
self._value = self.generate_value_from_proto(const_proto)

def set_step(self, step):
"""Set step value."""
@@ -222,16 +228,29 @@ class ConstTensor(BaseTensor):
@property
def value(self):
"""The property of tensor shape."""
fields = self._const_proto.value.ListFields()
return self._value

@staticmethod
def generate_value_from_proto(tensor_proto):
"""
Generate tensor value from proto.

Args:
tensor_proto (TensorProto): The tensor proto.

Returns:
Union[None, np.ndarray], the value of the tensor.
"""
fields = tensor_proto.value.ListFields()
if len(fields) != 2:
log.warning("Unexpected const proto <%s>.\n Please check offline.", self._const_proto)
log.warning("Unexpected const proto <%s>.\n Please check offline.", tensor_proto)
for field_name, field_value in fields:
if field_name != 'dtype':
return field_value
return None

def get_tensor_value_by_shape(self, shape=None):
def get_tensor_serializable_value_by_shape(self, shape=None):
"""Get tensor info with value."""
if shape is not None:
log.warning("Invalid shape for const value.")
return self.value
return self._value

+ 29
- 4
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -42,6 +42,9 @@ class TensorHandler(StreamHandlerBase):
- step (int): The current step of tensor.

- tensor_protos (list[TensorProto]): The tensor proto.

Returns:
bool, the tensor has updated successfully.
"""
tensor_protos = value.get('tensor_protos')
merged_tensor = self._get_merged_tensor(tensor_protos)
@@ -50,8 +53,9 @@ class TensorHandler(StreamHandlerBase):
log.debug("Received previous tensor.")
step -= 1
tensor = OpTensor(merged_tensor, step)
self._put_tensor_into_cache(tensor, step)
log.info("Put tensor %s of step: %d, into cache", tensor.name, step)
flag = self._put_tensor_into_cache(tensor, step)
log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, step, flag)
return flag

@staticmethod
def _get_merged_tensor(tensor_protos):
@@ -83,12 +87,34 @@ class TensorHandler(StreamHandlerBase):

Args:
tensor (OpTensor): The tensor value.
step (int): The step of tensor.

Returns:
bool, the tensor has updated successfully.
"""
cache_tensor = self._tensors.get(tensor.name)
if cache_tensor is None:
cache_tensor = {}
self._tensors[tensor.name] = cache_tensor

old_tensor = cache_tensor.get(step)
if old_tensor and not self.is_value_diff(old_tensor.value, tensor.value):
log.debug("Tensor %s of step %s has no change. Ignore it.")
return False
cache_tensor[step] = tensor
log.debug("Put updated tensor value for %s of step %s.", tensor.name, step)
return True

@staticmethod
def is_value_diff(old_value, new_value):
"""Check tensor value if there are equal."""
log.debug("old value type: %s, new_value type: %s", type(old_value), type(new_value))
if old_value is None and new_value is None:
return False
flag = old_value != new_value
if isinstance(flag, np.ndarray):
return flag.any()
return flag

def put_const_vals(self, const_vals):
"""
@@ -224,7 +250,7 @@ class TensorHandler(StreamHandlerBase):
if prev_step < 0:
return flag
tensor = self._get_tensor(tensor_name, step=prev_step)
return bool(tensor and tensor.value)
return bool(tensor and not tensor.empty)

def get_tensor_value_by_name(self, tensor_name, prev=False):
"""Get tensor value by name in numpy type."""
@@ -249,7 +275,6 @@ class TensorHandler(StreamHandlerBase):
expired_tensor.append(tensor_name)
for tensor_name in expired_tensor:
self._tensors.pop(tensor_name)
self._tensors = {}

def get_tensors_diff(self, tensor_name, shape, tolerance=0):
"""


Loading…
Cancel
Save