|
|
|
@@ -17,7 +17,6 @@ import numpy as np |
|
|
|
|
|
|
|
from mindinsight.datavisual.common.log import logger |
|
|
|
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket |
|
|
|
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 |
|
|
|
from mindinsight.datavisual.utils.utils import calc_histogram_bins |
|
|
|
from mindinsight.utils.exceptions import ParamValueError |
|
|
|
|
|
|
|
@@ -192,10 +191,8 @@ class TensorContainer: |
|
|
|
self._stats = get_statistics_from_tensor(self._np_array) |
|
|
|
original_buckets = calc_original_buckets(self._np_array, self._stats) |
|
|
|
self._count = sum(bucket.count for bucket in original_buckets) |
|
|
|
# convert the type of max and min value to np.float64 so that it cannot overflow |
|
|
|
# when calculating width of histogram. |
|
|
|
self._max = np.float64(self._stats.max) |
|
|
|
self._min = np.float64(self._stats.min) |
|
|
|
self._max = self._stats.max |
|
|
|
self._min = self._stats.min |
|
|
|
self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count) |
|
|
|
|
|
|
|
@property |
|
|
|
@@ -257,9 +254,4 @@ class TensorContainer: |
|
|
|
Returns: |
|
|
|
numpy.ndarray, ndarray of tensor. |
|
|
|
""" |
|
|
|
data_type_str = anf_ir_pb2.DataType.Name(self.data_type) |
|
|
|
if data_type_str == 'DT_FLOAT16': |
|
|
|
return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims) |
|
|
|
if data_type_str == 'DT_FLOAT32': |
|
|
|
return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims) |
|
|
|
return np.array(tuple(tensor)).reshape(self.dims) |