Browse Source

Support block query for tensor visualization and tensor comparisons.

tags/v1.0.0
wangshuide2020 5 years ago
parent
commit
a194bb1660
7 changed files with 183 additions and 141 deletions
  1. +8
    -8
      mindinsight/datavisual/processors/tensor_processor.py
  2. +0
    -24
      mindinsight/debugger/common/utils.py
  3. +6
    -21
      mindinsight/debugger/debugger_server.py
  4. +36
    -2
      mindinsight/debugger/stream_cache/tensor.py
  5. +27
    -8
      mindinsight/debugger/stream_handler/tensor_handler.py
  6. +96
    -71
      mindinsight/utils/tensor.py
  7. +10
    -7
      tests/ut/datavisual/processors/test_tensor_processor.py

+ 8
- 8
mindinsight/datavisual/processors/tensor_processor.py View File

@@ -19,7 +19,7 @@ import numpy as np

from mindinsight.datavisual.utils.tools import to_int
from mindinsight.utils.exceptions import ParamValueError, UrlDecodeError
from mindinsight.utils.tensor import TensorUtils
from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR
from mindinsight.conf.constants import MAX_TENSOR_RESPONSE_DATA_SIZE
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.common.exceptions import StepTensorDataNotInCacheError, TensorNotExistError
@@ -49,7 +49,6 @@ class TensorProcessor(BaseProcessor):
UrlDecodeError, If unquote train id error with strict mode.
"""
Validation.check_param_empty(train_id=train_ids, tag=tags)
TensorUtils.validate_dims_format(dims)

for index, train_id in enumerate(train_ids):
try:
@@ -99,6 +98,8 @@ class TensorProcessor(BaseProcessor):
values = self._get_tensors_summary(detail, tensors)
elif detail == 'data':
Validation.check_param_empty(step=step, dims=dims)
# Limit to query max two dimensions for tensor in table view.
dims = TensorUtils.parse_shape(dims, limit=MAX_DIMENSIONS_FOR_TENSOR)
step = to_int(step, "step")
values = self._get_tensors_data(step, dims, tensors)
elif detail == 'histogram':
@@ -152,7 +153,7 @@ class TensorProcessor(BaseProcessor):
"data_type": anf_ir_pb2.DataType.Name(value.data_type)
}
if detail and detail == 'stats':
stats = TensorUtils.get_statistics_dict(value.stats)
stats = TensorUtils.get_statistics_dict(stats=value.stats, overall_stats=value.stats)
value_dict.update({"statistics": stats})

values.append({
@@ -169,7 +170,7 @@ class TensorProcessor(BaseProcessor):

Args:
step (int): Specify step of tensor.
dims (str): Specify dims of tensor.
dims (tuple): Specify dims of tensor.
tensors (list): The list of _Tensor data.

Returns:
@@ -199,14 +200,13 @@ class TensorProcessor(BaseProcessor):
"""
values = []
step_in_cache = False
dims = TensorUtils.convert_array_from_str_dims(dims, limit=2)
for tensor in tensors:
# This value is an instance of TensorContainer
value = tensor.value
if step != tensor.step:
continue
step_in_cache = True
res_data = TensorUtils.get_specific_dims_data(value.ndarray, dims, list(value.dims))
res_data = TensorUtils.get_specific_dims_data(value.ndarray, dims)
flatten_data = res_data.flatten().tolist()
if len(flatten_data) > MAX_TENSOR_RESPONSE_DATA_SIZE:
raise ResponseDataExceedMaxValueError("the size of response data: {} exceed max value: {}."
@@ -244,7 +244,7 @@ class TensorProcessor(BaseProcessor):
"dims": value.dims,
"data_type": anf_ir_pb2.DataType.Name(value.data_type),
"data": tensor_data,
"statistics": TensorUtils.get_statistics_dict(stats)
"statistics": TensorUtils.get_statistics_dict(stats=stats, overall_stats=value.stats)
}
})
break
@@ -293,7 +293,7 @@ class TensorProcessor(BaseProcessor):
"dims": value.dims,
"data_type": anf_ir_pb2.DataType.Name(value.data_type),
"histogram_buckets": buckets,
"statistics": TensorUtils.get_statistics_dict(value.stats)
"statistics": TensorUtils.get_statistics_dict(stats=value.stats, overall_stats=value.stats)
}
})



+ 0
- 24
mindinsight/debugger/common/utils.py View File

@@ -18,8 +18,6 @@ from collections import namedtuple

import numpy as np

from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.proto.debug_grpc_pb2 import EventReply
from mindinsight.debugger.stream_cache.debugger_graph import NodeTypeEnum

@@ -144,25 +142,3 @@ def is_scope_type(node_type):
"""Judge whether the type is scope type."""
scope_types = [NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.AGGREGATION_SCOPE.value]
return node_type in scope_types


def str_to_slice_or_int(input_str):
"""
Translate param from string to slice or int.

Args:
input_str (str): The string to be translated.

Returns:
Union[int, slice], the transformed param.
"""
try:
if ':' in input_str:
ret = slice(*map(lambda x: int(x.strip()) if x.strip() else None, input_str.split(':')))
else:
ret = int(input_str)
except ValueError as err:
log.error("Failed to create slice from %s", input_str)
log.exception(err)
raise DebuggerParamValueError("Invalid shape.")
return ret

+ 6
- 21
mindinsight/debugger/debugger_server.py View File

@@ -27,13 +27,13 @@ from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValue
DebuggerDeleteWatchPointError, DebuggerContinueError, DebuggerPauseError, DebuggerCompareTensorError
from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.common.utils import get_ack_reply, ServerStatus, \
create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo, \
str_to_slice_or_int
create_view_event_from_tensor_history, Streams, is_scope_type, NodeBasicInfo
from mindinsight.debugger.debugger_cache import DebuggerCache
from mindinsight.debugger.debugger_grpc_server import DebuggerGrpcServer
from mindinsight.debugger.proto import debug_grpc_pb2_grpc as grpc_server_base
from mindinsight.debugger.proto.debug_grpc_pb2 import RunCMD
from mindinsight.utils.exceptions import MindInsightException
from mindinsight.utils.tensor import TensorUtils, MAX_DIMENSIONS_FOR_TENSOR


class DebuggerServer:
@@ -128,7 +128,8 @@ class DebuggerServer:
"Failed to compare tensors as the MindSpore is not in waiting state."
)
self.validate_tensor_param(name, detail)
parsed_shape = self.parse_shape(shape)
# Limit to query max two dimensions for tensor in table view.
parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
tolerance = to_float(tolerance, 'tolerance')
tensor_stream = self.cache_store.get_stream_handler(Streams.TENSOR)
@@ -303,7 +304,8 @@ class DebuggerServer:
"""Retrieve the tensor value."""
log.info("Retrieve tensor value: name: %s, detail: %s, shape: %s", name, detail, shape)
self.validate_tensor_param(name, detail)
parsed_shape = self.parse_shape(shape)
# Limit to query max two dimensions for tensor in table view.
parsed_shape = TensorUtils.parse_shape(shape, limit=MAX_DIMENSIONS_FOR_TENSOR)
node_type, tensor_name = self._get_tensor_name_and_type_by_ui_name(name)
reply = self.cache_store.get_stream_handler(Streams.TENSOR).get(
{'name': tensor_name,
@@ -344,23 +346,6 @@ class DebuggerServer:
log.error("Invalid detail value. Received: %s", detail)
raise DebuggerParamValueError("Invalid detail value.")

@staticmethod
def parse_shape(shape):
"""Parse shape."""
if shape is None:
return shape
if not (isinstance(shape, str) and shape.startswith('[') and shape.endswith(']')):
log.error("Invalid shape. Received: %s", shape)
raise DebuggerParamValueError("Invalid shape.")
shape = shape.strip('[]')
if shape.count(':') > 2:
log.error("Invalid shape. At most two dimensions are specified.")
raise DebuggerParamValueError("Invalid shape.")
parsed_shape = tuple(
str_to_slice_or_int(dim) for dim in shape.split(',')) if shape else tuple()
log.info("Parsed shape: %s from %s", parsed_shape, shape)
return parsed_shape

def _retrieve_watchpoint(self, filter_condition):
"""
Retrieve watchpoint.


+ 36
- 2
mindinsight/debugger/stream_cache/tensor.py View File

@@ -98,6 +98,8 @@ class OpTensor(BaseTensor):
super(OpTensor, self).__init__(step)
self._tensor_proto = tensor_proto
self._value = self.generate_value_from_proto(tensor_proto)
self._stats = None
self._tensor_comparison = None

@property
def name(self):
@@ -123,6 +125,16 @@ class OpTensor(BaseTensor):
"""The property of tensor value."""
return self._value

@property
def stats(self):
"""The property of tensor stats."""
return self._stats

@property
def tensor_comparison(self):
"""The property of tensor_comparison."""
return self._tensor_comparison

def generate_value_from_proto(self, tensor_proto):
"""
Generate tensor value from proto.
@@ -156,13 +168,35 @@ class OpTensor(BaseTensor):
# the type of tensor_value is one of None, np.ndarray or str
if isinstance(tensor_value, np.ndarray):
statistics = TensorUtils.get_statistics_from_tensor(tensor_value)
res['statistics'] = TensorUtils.get_statistics_dict(statistics)
if not self.stats:
self.update_tensor_stats(TensorUtils.get_statistics_from_tensor(self.value))
res['statistics'] = TensorUtils.get_statistics_dict(stats=statistics, overall_stats=self.stats)
res['value'] = tensor_value.tolist()
elif isinstance(tensor_value, str):
res['value'] = tensor_value

return res

def update_tensor_comparisons(self, tensor_comparison):
"""
Update tensor comparison for tensor.

Args:
tensor_comparison (TensorComparison) instance of TensorComparison.

"""
self._tensor_comparison = tensor_comparison

def update_tensor_stats(self, stats):
"""
Update tensor stats.

Args:
stats (Statistics) instance of Statistics.

"""
self._stats = stats

def get_tensor_value_by_shape(self, shape=None):
"""
Get tensor value by shape.
@@ -190,8 +224,8 @@ class OpTensor(BaseTensor):
raise DebuggerParamValueError("Invalid shape. Shape unmatched.")
if isinstance(value, np.ndarray):
if value.size > self.max_number_data_show_on_ui:
log.info("The tensor size is %d, which is too large to show on UI.", value.size)
value = "Too large to show."
log.info("The tensor size is %s, which is too large to show on UI.")
else:
value = np.asarray(value)
return value


+ 27
- 8
mindinsight/debugger/stream_handler/tensor_handler.py View File

@@ -21,7 +21,7 @@ from mindinsight.debugger.common.log import logger as log
from mindinsight.debugger.proto.ms_graph_pb2 import DataType
from mindinsight.debugger.stream_cache.tensor import OpTensor, ConstTensor
from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
from mindinsight.utils.tensor import TensorUtils
from mindinsight.utils.tensor import TensorUtils, TensorComparison


class TensorHandler(StreamHandlerBase):
@@ -291,7 +291,8 @@ class TensorHandler(StreamHandlerBase):
boundary value, the result will set to be zero.

Raises:
DebuggerParamValueError, If get current step node and previous step node failed.
DebuggerParamValueError, If get current step node and previous step node failed or
the type of tensor value is not numpy.ndarray."

Returns:
dict, the retrieved data.
@@ -306,15 +307,33 @@ class TensorHandler(StreamHandlerBase):
prev_tensor_slice = prev_tensor.get_tensor_value_by_shape(shape)
tensor_info = curr_tensor.get_basic_info()
if isinstance(tensor_info, dict):
del tensor_info['has_prev_step']
del tensor_info['value']
tensor_info.pop('has_prev_step')
tensor_info.pop('value')

tensor_comparison = curr_tensor.tensor_comparison
if not tensor_comparison or tensor_comparison.tolerance != tolerance:
if isinstance(curr_tensor.value, np.ndarray) and isinstance(prev_tensor.value, np.ndarray):
tensor_diff = TensorUtils.calc_diff_between_two_tensor(curr_tensor.value, prev_tensor.value, tolerance)
if not tensor_comparison:
stats = TensorUtils.get_statistics_from_tensor(tensor_diff)
tensor_comparison = TensorComparison(tolerance, stats, tensor_diff)
curr_tensor.update_tensor_comparisons(tensor_comparison)
else:
tensor_comparison.update(tolerance=tolerance, value=tensor_diff)
else:
raise DebuggerParamValueError("The type of tensor value should be numpy.ndarray.")

# the type of curr_tensor_slice is one of None, np.ndarray or str
if isinstance(curr_tensor_slice, np.ndarray) and isinstance(prev_tensor_slice, np.ndarray):
diff_tensor = TensorUtils.calc_diff_between_two_tensor(curr_tensor_slice, prev_tensor_slice, tolerance)
result = np.stack([prev_tensor_slice, curr_tensor_slice, diff_tensor], axis=-1)
if not shape:
tensor_diff_slice = tensor_comparison.value
else:
tensor_diff_slice = tensor_comparison.value[shape]
result = np.stack([prev_tensor_slice, curr_tensor_slice, tensor_diff_slice], axis=-1)
tensor_info['diff'] = result.tolist()
stats = TensorUtils.get_statistics_from_tensor(diff_tensor)
tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats)
stats = TensorUtils.get_statistics_from_tensor(tensor_diff_slice)
tensor_info['statistics'] = TensorUtils.get_statistics_dict(stats=stats,
overall_stats=tensor_comparison.stats)
elif isinstance(curr_tensor_slice, str):
tensor_info['diff'] = curr_tensor_slice
reply = {'tensor_value': tensor_info}


+ 96
- 71
mindinsight/utils/tensor.py View File

@@ -16,13 +16,12 @@

import numpy as np

from mindinsight.datavisual.utils.tools import to_int
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.utils.exceptions import ParamTypeError
from mindinsight.utils.log import utils_logger as logger

F32_MIN, F32_MAX = np.finfo(np.float32).min, np.finfo(np.float32).max
MAX_DIMENSIONS_FOR_TENSOR = 2

class Statistics:
"""Statistics data class.
@@ -82,97 +81,115 @@ class Statistics:
"""Get count of positive INF."""
return self._pos_inf_count

class TensorComparison:
"""TensorComparison class.

class TensorUtils:
"""Tensor Utils class."""
Args:
tolerance (float): tolerance for calculating tensor diff.
stats (float): statistics of tensor diff.
value (numpy.ndarray): tensor diff.
"""
def __init__(self, tolerance=0, stats=None, value=None):
self._tolerance = tolerance
self._stats = stats
self._value = value

@staticmethod
def validate_dims_format(dims):
"""
Validate correct of format of dimension parameter.
@property
def tolerance(self):
"""Get tolerance of TensorComparison."""
return self._tolerance

Args:
dims (str): Dims of tensor. Its format is something like this "[0, 0, :, :]".
@property
def stats(self):
"""Get stats of tensor diff."""
return self._stats

Raises:
ParamValueError: If format of dims is not correct.
"""
if dims is not None:
if not isinstance(dims, str):
raise ParamTypeError(dims, str)
dims = dims.strip()
if not (dims.startswith('[') and dims.endswith(']')):
raise ParamValueError('The value: {} of dims must be '
'start with `[` and end with `]`.'.format(dims))
for dim in dims[1:-1].split(','):
dim = dim.strip()
if dim == ":":
continue
if dim.startswith('-'):
dim = dim[1:]
if not dim.isdigit():
raise ParamValueError('The value: {} of dims in the square brackets '
'must be int or `:`.'.format(dims))
def update(self, tolerance, value):
"""update tensor comparisons."""
self._tolerance = tolerance
self._value = value

@property
def value(self):
"""Get value of tensor diff."""
return self._value

def str_to_slice_or_int(input_str):
"""
Translate param from string to slice or int.

Args:
input_str (str): The string to be translated.

Returns:
Union[int, slice], the transformed param.
"""
try:
if ':' in input_str:
ret = slice(*map(lambda x: int(x.strip()) if x.strip() else None, input_str.split(':')))
else:
ret = int(input_str)
except ValueError:
raise ParamValueError("Invalid shape. Convert int from str failed. input_str: {}".format(input_str))
return ret


class TensorUtils:
"""Tensor Utils class."""

@staticmethod
def convert_array_from_str_dims(dims, limit=0):
def parse_shape(shape, limit=0):
"""
Convert string of dims data to array.
Parse shape from str.

Args:
dims (str): Specify dims of tensor.
limit (int): The max flexible dimension count, default value is 0 which means that there is no limitation.
shape (str): Specify shape of tensor.
limit (int): The max dimensions specified. Default value is 0 which means that there is no limitation.

Returns:
list, a string like this: "[0, 0, :, :]" will convert to this value: [0, 0, None, None].
Union[None, tuple], a string like this: "[0, 0, 1:10, :]" will convert to this value:
(0, 0, slice(1, 10, None), slice(None, None, None)].

Raises:
ParamValueError, If flexible dimensions exceed limit value.
ParamValueError, If type of shape is not str or format is not correct or exceed specified dimensions.
"""
dims = dims.strip().lstrip('[').rstrip(']')
dims_list = []
count = 0
for dim in dims.split(','):
dim = dim.strip()
if dim == ':':
dims_list.append(None)
count += 1
else:
dims_list.append(to_int(dim, "dim"))
if limit and count > limit:
raise ParamValueError("Flexible dimensions cannot exceed limit value: {}, size: {}"
.format(limit, count))
return dims_list
if shape is None:
return shape
if not (isinstance(shape, str) and shape.strip().startswith('[') and shape.strip().endswith(']')):
raise ParamValueError("Invalid shape. The type of shape should be str and start with `[` and "
"end with `]`. Received: {}.".format(shape))
shape = shape.strip()[1:-1]
dimension_size = sum(1 for dim in shape.split(',') if dim.count(':'))
if limit and dimension_size > limit:
raise ParamValueError("Invalid shape. At most {} dimensions are specified. Received: {}"
.format(limit, shape))
parsed_shape = tuple(
str_to_slice_or_int(dim.strip()) for dim in shape.split(',')) if shape else tuple()
return parsed_shape

@staticmethod
def get_specific_dims_data(ndarray, dims, tensor_dims):
def get_specific_dims_data(ndarray, dims):
"""
Get specific dims data.

Args:
ndarray (numpy.ndarray): An ndarray of numpy.
dims (list): A list of specific dims.
tensor_dims (list): A list of tensor dims.
dims (tuple): A tuple of specific dims.

Returns:
numpy.ndarray, an ndarray of specific dims tensor data.

Raises:
ParamValueError, If the length of param dims is not equal to the length of tensor dims or
the index of param dims out of range.
ParamValueError, If the length of param dims is not equal to the length of tensor dims.
IndexError, If the param dims and tensor shape is unmatched.
"""
if len(dims) != len(tensor_dims):
raise ParamValueError("The length of param dims: {}, is not equal to the "
"length of tensor dims: {}.".format(len(dims), len(tensor_dims)))
indices = []
for k, d in enumerate(dims):
if d is not None:
if d >= tensor_dims[k]:
raise ParamValueError("The index: {} of param dims out of range: {}.".format(d, tensor_dims[k]))
indices.append(d)
else:
indices.append(slice(0, tensor_dims[k]))
result = ndarray[tuple(indices)]
if len(ndarray.shape) != len(dims):
raise ParamValueError("Invalid dims. The length of param dims and tensor shape should be the same.")
try:
result = ndarray[dims]
except IndexError:
raise ParamValueError("Invalid shape. Shape unmatched. Received: {}, tensor shape: {}"
.format(dims, ndarray.shape))
# Make sure the return type is numpy.ndarray.
if not isinstance(result, np.ndarray):
result = np.array(result)
@@ -233,15 +250,17 @@ class TensorUtils:
return statistics

@staticmethod
def get_statistics_dict(stats):
def get_statistics_dict(stats, overall_stats):
"""
Get statistics dict according to statistics value.

Args:
stats (Statistics): An instance of Statistics.
stats (Statistics): An instance of Statistics for sliced tensor.
overall_stats (Statistics): An instance of Statistics for whole tensor.

Returns:
dict, a dict including 'max', 'min', 'avg', 'count', 'nan_count', 'neg_inf_count', 'pos_inf_count'.
dict, a dict including 'max', 'min', 'avg', 'count',
'nan_count', 'neg_inf_count', 'pos_inf_count', 'overall_max', 'overall_min'.
"""
statistics = {
"max": float(stats.max),
@@ -250,7 +269,9 @@ class TensorUtils:
"count": stats.count,
"nan_count": stats.nan_count,
"neg_inf_count": stats.neg_inf_count,
"pos_inf_count": stats.pos_inf_count
"pos_inf_count": stats.pos_inf_count,
"overall_max": float(overall_stats.max),
"overall_min": float(overall_stats.min)
}
return statistics

@@ -274,7 +295,8 @@ class TensorUtils:

Raises:
ParamTypeError: If the type of these two tensors is not the numpy.ndarray.
ParamValueError: If the shape or dtype is not the same of these two tensors.
ParamValueError: If the shape or dtype is not the same of these two tensors or
the tolerance should be between 0 and 1.
"""
if not isinstance(first_tensor, np.ndarray):
raise ParamTypeError('first_tensor', np.ndarray)
@@ -289,6 +311,9 @@ class TensorUtils:
if first_tensor.dtype != second_tensor.dtype:
raise ParamValueError("the dtype: {} of first tensor is not equal to dtype: {} of second tensor."
.format(first_tensor.dtype, second_tensor.dtype))
# Make sure tolerance is between 0 and 1.
if tolerance < 0 or tolerance > 1:
raise ParamValueError("the tolerance should be between 0 and 1, but got {}".format(tolerance))

diff_tensor = np.subtract(first_tensor, second_tensor)
stats = TensorUtils.get_statistics_from_tensor(diff_tensor)


+ 10
- 7
tests/ut/datavisual/processors/test_tensor_processor.py View File

@@ -150,8 +150,7 @@ class TestTensorProcessor:
processor.get_tensors([self._train_id], [test_tag_name], step='1', dims='[0,:]', detail='data')

assert exc_info.value.error_code == '50540002'
assert "Invalid parameter value. The length of param dims: 2, is not equal to the length of tensor dims: 4" \
in exc_info.value.message
assert "The length of param dims and tensor shape should be the same" in exc_info.value.message

@pytest.mark.usefixtures('load_tensor_record')
def test_get_tensor_data_with_exceed_two_dims(self):
@@ -163,8 +162,7 @@ class TestTensorProcessor:
processor.get_tensors([self._train_id], [test_tag_name], step='1', dims='[0,:,:,:]', detail='data')

assert exc_info.value.error_code == '50540002'
assert "Invalid parameter value. Flexible dimensions cannot exceed limit value: 2, size: 3" \
in exc_info.value.message
assert "Invalid shape. At most 2 dimensions are specified" in exc_info.value.message

@pytest.mark.usefixtures('load_tensor_record')
def test_get_tensor_data_success(self):
@@ -172,7 +170,7 @@ class TestTensorProcessor:
test_tag_name = self._complete_tag_name

processor = TensorProcessor(self._mock_data_manager)
results = processor.get_tensors([self._train_id], [test_tag_name], step='1', dims='[0,0,:,:]', detail='data')
results = processor.get_tensors([self._train_id], [test_tag_name], step='1', dims='[0,0,:-1,:]', detail='data')

recv_metadata = results.get('tensors')[0].get("values")

@@ -182,7 +180,11 @@ class TestTensorProcessor:
dims = expected_values.get('value').get("dims")
expected_data = np.array(expected_values.get('value').get("float_data")).reshape(dims)
recv_tensor = np.array(recv_values.get('value').get("data"))
expected_tensor = TensorUtils.get_specific_dims_data(expected_data, [0, 0, None, None], dims)
expected_tensor = TensorUtils.get_specific_dims_data(
expected_data, (0, 0, slice(None, -1, None), slice(None)))
# Compare tensor shape when recv_tensor shape is not empty.
if recv_tensor.shape != (0,):
assert recv_tensor.shape == expected_tensor.shape
assert np.sum(np.isclose(recv_tensor, expected_tensor, rtol=1e-6) == 0) == 0

@pytest.mark.usefixtures('load_tensor_record')
@@ -200,7 +202,8 @@ class TestTensorProcessor:
assert recv_values.get('step') == expected_values.get('step')
expected_data = expected_values.get('value').get("float_data")
expected_statistic_instance = TensorUtils.get_statistics_from_tensor(expected_data)
expected_statistic = TensorUtils.get_statistics_dict(expected_statistic_instance)
expected_statistic = TensorUtils.get_statistics_dict(stats=expected_statistic_instance,
overall_stats=expected_statistic_instance)
recv_statistic = recv_values.get('value').get("statistics")
assert recv_statistic.get("max") - expected_statistic.get("max") < 1e-6
assert recv_statistic.get("min") - expected_statistic.get("min") < 1e-6


Loading…
Cancel
Save