| @@ -171,7 +171,6 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| def _get_next_command(self): | |||
| """Get next command.""" | |||
| self._pos, event = self._cache_store.get_command(self._pos) | |||
| log.debug("Received event :%s", event) | |||
| if event is None: | |||
| return event | |||
| if isinstance(event, dict): | |||
| @@ -188,7 +187,7 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| """Deal with view cmd.""" | |||
| view_cmd = event.get('view_cmd') | |||
| node_name = event.get('node_name') | |||
| log.debug("Receive view cmd %s for node: %s.", view_cmd, node_name) | |||
| log.debug("Receive view cmd for node: %s.", node_name) | |||
| if not (view_cmd and node_name): | |||
| log.debug("Invalid view command. Ignore it.") | |||
| return None | |||
| @@ -265,10 +264,10 @@ class DebuggerGrpcServer(grpc_server_base.EventListenerServicer): | |||
| for tensor in request_iterator: | |||
| tensor_construct.append(tensor) | |||
| if tensor.finished: | |||
| if self._received_view_cmd.get('wait_for_tensor') and tensor.tensor_content: | |||
| update_flag = tensor_stream.put({'step': step, 'tensor_protos': tensor_construct}) | |||
| if self._received_view_cmd.get('wait_for_tensor') and update_flag: | |||
| self._received_view_cmd['wait_for_tensor'] = False | |||
| log.debug("Set wait for tensor flag to False.") | |||
| tensor_stream.put({'step': step, 'tensor_protos': tensor_construct}) | |||
| tensor_construct = [] | |||
| tensor_names.append(':'.join([tensor.node_name, tensor.slot])) | |||
| continue | |||
| @@ -50,8 +50,13 @@ class BaseTensor(ABC): | |||
| def value(self): | |||
| """The property of tensor shape.""" | |||
| @property | |||
| def empty(self): | |||
| """If the tensor value is valid.""" | |||
| return self.value is None | |||
| @abstractmethod | |||
| def get_tensor_value_by_shape(self, shape=None): | |||
| def get_tensor_serializable_value_by_shape(self, shape=None): | |||
| """Get tensor value by shape.""" | |||
| def _to_dict(self): | |||
| @@ -67,8 +72,9 @@ class BaseTensor(ABC): | |||
| def get_basic_info(self): | |||
| """Return basic info about tensor info.""" | |||
| tensor_value = self.value | |||
| if not self.shape: | |||
| value = self.value | |||
| value = tensor_value.tolist() if isinstance(tensor_value, np.ndarray) else tensor_value | |||
| else: | |||
| value = 'click to view' | |||
| res = self._to_dict() | |||
| @@ -91,7 +97,7 @@ class OpTensor(BaseTensor): | |||
| # the type of tensor_proto is TensorProto | |||
| super(OpTensor, self).__init__(step) | |||
| self._tensor_proto = tensor_proto | |||
| self._value = self.generate_value(tensor_proto) | |||
| self._value = self.generate_value_from_proto(tensor_proto) | |||
| @property | |||
| def name(self): | |||
| @@ -115,19 +121,18 @@ class OpTensor(BaseTensor): | |||
| @property | |||
| def value(self): | |||
| """The property of tensor value.""" | |||
| tensor_value = None | |||
| if self._value is not None: | |||
| tensor_value = self._value.tolist() | |||
| return self._value | |||
| return tensor_value | |||
| def generate_value_from_proto(self, tensor_proto): | |||
| """ | |||
| Generate tensor value from proto. | |||
| @property | |||
| def numpy_value(self): | |||
| """The property of tensor value in numpy type.""" | |||
| return self._value | |||
| Args: | |||
| tensor_proto (TensorProto): The tensor proto. | |||
| def generate_value(self, tensor_proto): | |||
| """Generate tensor value from proto.""" | |||
| Returns: | |||
| Union[None, np.ndarray], the value of the tensor. | |||
| """ | |||
| tensor_value = None | |||
| if tensor_proto.tensor_content: | |||
| tensor_value = tensor_proto.tensor_content | |||
| @@ -166,7 +171,7 @@ class OpTensor(BaseTensor): | |||
| shape (tuple): The specified shape. | |||
| Returns: | |||
| Union[None, str, numpy.ndarray], the sub-tensor. | |||
| Union[None, str, numpy.ndarray], the value of parsed tensor. | |||
| """ | |||
| if self._value is None: | |||
| log.warning("%s has no value yet.", self.name) | |||
| @@ -199,6 +204,7 @@ class ConstTensor(BaseTensor): | |||
| # the type of const_proto is NamedValueProto | |||
| super(ConstTensor, self).__init__() | |||
| self._const_proto = const_proto | |||
| self._value = self.generate_value_from_proto(const_proto) | |||
| def set_step(self, step): | |||
| """Set step value.""" | |||
| @@ -222,16 +228,29 @@ class ConstTensor(BaseTensor): | |||
| @property | |||
| def value(self): | |||
| """The property of tensor shape.""" | |||
| fields = self._const_proto.value.ListFields() | |||
| return self._value | |||
| @staticmethod | |||
| def generate_value_from_proto(tensor_proto): | |||
| """ | |||
| Generate tensor value from proto. | |||
| Args: | |||
| tensor_proto (TensorProto): The tensor proto. | |||
| Returns: | |||
| Union[None, np.ndarray], the value of the tensor. | |||
| """ | |||
| fields = tensor_proto.value.ListFields() | |||
| if len(fields) != 2: | |||
| log.warning("Unexpected const proto <%s>.\n Please check offline.", self._const_proto) | |||
| log.warning("Unexpected const proto <%s>.\n Please check offline.", tensor_proto) | |||
| for field_name, field_value in fields: | |||
| if field_name != 'dtype': | |||
| return field_value | |||
| return None | |||
| def get_tensor_value_by_shape(self, shape=None): | |||
| def get_tensor_serializable_value_by_shape(self, shape=None): | |||
| """Get tensor info with value.""" | |||
| if shape is not None: | |||
| log.warning("Invalid shape for const value.") | |||
| return self.value | |||
| return self._value | |||
| @@ -42,6 +42,9 @@ class TensorHandler(StreamHandlerBase): | |||
| - step (int): The current step of tensor. | |||
| - tensor_protos (list[TensorProto]): The tensor proto. | |||
| Returns: | |||
| bool, the tensor has updated successfully. | |||
| """ | |||
| tensor_protos = value.get('tensor_protos') | |||
| merged_tensor = self._get_merged_tensor(tensor_protos) | |||
| @@ -50,8 +53,9 @@ class TensorHandler(StreamHandlerBase): | |||
| log.debug("Received previous tensor.") | |||
| step -= 1 | |||
| tensor = OpTensor(merged_tensor, step) | |||
| self._put_tensor_into_cache(tensor, step) | |||
| log.info("Put tensor %s of step: %d, into cache", tensor.name, step) | |||
| flag = self._put_tensor_into_cache(tensor, step) | |||
| log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, step, flag) | |||
| return flag | |||
| @staticmethod | |||
| def _get_merged_tensor(tensor_protos): | |||
| @@ -83,12 +87,34 @@ class TensorHandler(StreamHandlerBase): | |||
| Args: | |||
| tensor (OpTensor): The tensor value. | |||
| step (int): The step of tensor. | |||
| Returns: | |||
| bool, the tensor has updated successfully. | |||
| """ | |||
| cache_tensor = self._tensors.get(tensor.name) | |||
| if cache_tensor is None: | |||
| cache_tensor = {} | |||
| self._tensors[tensor.name] = cache_tensor | |||
| old_tensor = cache_tensor.get(step) | |||
| if old_tensor and not self.is_value_diff(old_tensor.value, tensor.value): | |||
| log.debug("Tensor %s of step %s has no change. Ignore it.") | |||
| return False | |||
| cache_tensor[step] = tensor | |||
| log.debug("Put updated tensor value for %s of step %s.", tensor.name, step) | |||
| return True | |||
| @staticmethod | |||
| def is_value_diff(old_value, new_value): | |||
| """Check tensor value if there are equal.""" | |||
| log.debug("old value type: %s, new_value type: %s", type(old_value), type(new_value)) | |||
| if old_value is None and new_value is None: | |||
| return False | |||
| flag = old_value != new_value | |||
| if isinstance(flag, np.ndarray): | |||
| return flag.any() | |||
| return flag | |||
| def put_const_vals(self, const_vals): | |||
| """ | |||
| @@ -224,7 +250,7 @@ class TensorHandler(StreamHandlerBase): | |||
| if prev_step < 0: | |||
| return flag | |||
| tensor = self._get_tensor(tensor_name, step=prev_step) | |||
| return bool(tensor and tensor.value) | |||
| return bool(tensor and not tensor.empty) | |||
| def get_tensor_value_by_name(self, tensor_name, prev=False): | |||
| """Get tensor value by name in numpy type.""" | |||
| @@ -249,7 +275,6 @@ class TensorHandler(StreamHandlerBase): | |||
| expired_tensor.append(tensor_name) | |||
| for tensor_name in expired_tensor: | |||
| self._tensors.pop(tensor_name) | |||
| self._tensors = {} | |||
| def get_tensors_diff(self, tensor_name, shape, tolerance=0): | |||
| """ | |||