| @@ -93,7 +93,9 @@ class EventsData: | |||
| with self._reservoir_mutex_lock: | |||
| if tag not in self._reservoir_by_tag: | |||
| reservoir_size = self._get_reservoir_size(tensor_event.plugin_name) | |||
| self._reservoir_by_tag[tag] = reservoir.Reservoir(reservoir_size) | |||
| self._reservoir_by_tag[tag] = reservoir.ReservoirFactory().create_reservoir( | |||
| plugin_name, reservoir_size | |||
| ) | |||
| tensor = _Tensor(wall_time=tensor_event.wall_time, | |||
| step=tensor_event.step, | |||
| @@ -0,0 +1,98 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Histogram data container.""" | |||
| import math | |||
| from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Summary | |||
| def _mask_invalid_number(num): | |||
| """Mask invalid number to 0.""" | |||
| if math.isnan(num) or math.isinf(num): | |||
| return type(num)(0) | |||
| return num | |||
class HistogramContainer:
    """
    Container wrapping one histogram message read from a summary file.

    It keeps the original buckets as an immutable tuple and hands out "visual"
    buckets, which are meant to be re-sampled to a caller-provided range.

    Args:
        histogram_message (Summary.Histogram): Histogram message in summary file.
    """

    def __init__(self, histogram_message: Summary.Histogram):
        self._msg = histogram_message
        self._count = histogram_message.count

        # NaN / inf extremes are masked to 0 by _mask_invalid_number.
        self._min = _mask_invalid_number(histogram_message.min)
        self._max = _mask_invalid_number(histogram_message.max)

        # Until set_visual_range() is called, the visual range is the data range.
        self._visual_min, self._visual_max = self._min, self._max
        # default bin number
        self._visual_bins = 10

        self._original_buckets = tuple(
            (item.left, item.width, item.count) for item in histogram_message.buckets
        )
        # The tuple is immutable, so both attributes can reference the same object.
        self._re_sampled_buckets = self._original_buckets

    @property
    def max(self):
        """Gets max value of the tensor (NaN/inf masked to 0)."""
        return self._max

    @property
    def min(self):
        """Gets min value of the tensor (NaN/inf masked to 0)."""
        return self._min

    @property
    def count(self):
        """Gets the count field of the wrapped message."""
        return self._count

    @property
    def original_msg(self):
        """Gets the original proto message passed to the constructor."""
        return self._msg

    def set_visual_range(self, max_val: float, min_val: float, bins: int) -> None:
        """
        Stores the visual range used for later re-sampling.

        The input is not validated here; the caller must pass valid values.

        Args:
            max_val (float): Max value for visual histogram.
            min_val (float): Min value for visual histogram.
            bins (int): Bins number for visual histogram.
        """
        self._visual_bins = bins
        self._visual_min = min_val
        self._visual_max = max_val

        # An empty tuple marks the cached buckets as stale; buckets() rebuilds them.
        self._re_sampled_buckets = ()

    def _re_sample_buckets(self):
        """Rebuilds the cached visual buckets (currently a copy of the originals)."""
        # Will call re-sample logic in later PR.
        self._re_sampled_buckets = self._original_buckets

    def buckets(self):
        """
        Gets visual buckets instead of original buckets.

        Returns:
            tuple, `(left, width, count)` entries; rebuilt first when the cache is empty.
        """
        if self._re_sampled_buckets:
            return self._re_sampled_buckets

        self._re_sample_buckets()
        return self._re_sampled_buckets
| @@ -36,6 +36,7 @@ from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summar | |||
| from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 | |||
| from mindinsight.datavisual.utils import crc32 | |||
| from mindinsight.utils.exceptions import UnknownError | |||
| from mindinsight.datavisual.data_transform.histogram_container import HistogramContainer | |||
| HEADER_SIZE = 8 | |||
| CRC_STR_SIZE = 4 | |||
| @@ -235,7 +236,7 @@ class MSDataLoader: | |||
| self._events_data.add_tensor_event(tensor_event) | |||
| if value.HasField('histogram'): | |||
| histogram_msg = value.histogram | |||
| histogram_msg = HistogramContainer(value.histogram) | |||
| tag = '{}/{}'.format(value.tag, PluginNameEnum.HISTOGRAM.value) | |||
| tensor_event = TensorEvent(wall_time=event.wall_time, | |||
| step=event.step, | |||
| @@ -16,7 +16,9 @@ | |||
| import random | |||
| import threading | |||
| import math | |||
| from mindinsight.datavisual.common.enums import PluginNameEnum | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| @@ -106,3 +108,118 @@ class Reservoir: | |||
| round(self._sample_counter * sample_remaining_rate)) | |||
| return remove_size | |||
| class _VisualRange: | |||
| """Simple helper class to merge visual ranges.""" | |||
| def __init__(self): | |||
| self._max = 0.0 | |||
| self._min = 0.0 | |||
| self._updated = False | |||
| def update(self, max_val: float, min_val: float) -> None: | |||
| """ | |||
| Merge visual range with given range. | |||
| Args: | |||
| max_val (float): Max value of given range. | |||
| min_val (float): Min value of given range. | |||
| """ | |||
| if not self._updated: | |||
| self._max = max_val | |||
| self._min = min_val | |||
| self._updated = True | |||
| return | |||
| if max_val > self._max: | |||
| self._max = max_val | |||
| if min_val < self._min: | |||
| self._min = min_val | |||
| @property | |||
| def max(self): | |||
| """Gets max value of current range.""" | |||
| return self._max | |||
| @property | |||
| def min(self): | |||
| """Gets min value of current range.""" | |||
| return self._min | |||
class HistogramReservoir(Reservoir):
    """
    Reservoir for histograms, whose visual range is refreshed over all stored steps.

    Args:
        size (int): Container Size. If the size is 0, the container is not limited.
    """

    def __init__(self, size):
        super().__init__(size)

    def samples(self):
        """
        Return all stored samples after aligning their visual range.

        Every stored histogram receives the same range (the union of the ranges
        of all non-empty histograms) and the same bin number, derived from the
        largest count seen.

        Returns:
            list, a shallow copy of the stored samples.
        """
        # NOTE(review): _mutex and _samples are presumably provided by Reservoir - confirm.
        with self._mutex:
            merged_range = _VisualRange()
            largest_count = 0
            for item in self._samples:
                hist = item.value
                # Empty tensors carry no range information, so they are left out.
                if hist.count != 0:
                    largest_count = max(hist.count, largest_count)
                    merged_range.update(hist.max, hist.min)

            bin_number = self._calc_bins(largest_count)
            range_max, range_min = merged_range.max, merged_range.min

            # Push the shared range to every sample, including the empty ones.
            for item in self._samples:
                item.value.set_visual_range(range_max, range_min, bin_number)

            return list(self._samples)

    def _calc_bins(self, count):
        """
        Calculates experience-based optimal bins number.

        To suppress re-sample bias, each bin should hold enough numbers, so the
        bin number is derived from count. Very small counts (1 - 10) get hand
        picked values. Larger counts aim at 9-10 numbers per bucket on average.
        Too many bins would distract users, so the result is capped at 30.

        Args:
            count (int): Largest number count among the histograms.

        Returns:
            int, number of bins to use.
        """
        max_bins = 30
        number_per_bucket = 10

        if not count:
            bins = 1
        elif count <= 5:
            bins = 2
        elif count <= 10:
            bins = 3
        elif count <= 280:
            # note that math.ceil(281/10) + 1 = 30, so the cap is met continuously
            bins = math.ceil(count / number_per_bucket) + 1
        else:
            bins = max_bins
        return bins
class ReservoirFactory:
    """Factory that selects the reservoir implementation for a plugin."""

    def create_reservoir(self, plugin_name: str, size: int) -> Reservoir:
        """
        Creates reservoir for given plugin name.

        Args:
            plugin_name (str): Plugin name
            size (int): Container Size. If the size is 0, the container is not limited.

        Returns:
            Reservoir, a HistogramReservoir for the histogram plugin, a plain
            Reservoir for every other plugin name.
        """
        is_histogram = plugin_name == PluginNameEnum.HISTOGRAM.value
        reservoir_cls = HistogramReservoir if is_histogram else Reservoir
        return reservoir_cls(size)
| @@ -53,9 +53,8 @@ class HistogramProcessor(BaseProcessor): | |||
| histograms = [] | |||
| for tensor in tensors: | |||
| buckets = [] | |||
| for bucket in tensor.value.buckets: | |||
| buckets.append([bucket.left, bucket.width, bucket.count]) | |||
| histogram = tensor.value | |||
| buckets = histogram.buckets() | |||
| histograms.append({ | |||
| "wall_time": tensor.wall_time, | |||
| "step": tensor.step, | |||
| @@ -0,0 +1,34 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test histogram.""" | |||
| import unittest.mock as mock | |||
| from mindinsight.datavisual.data_transform import histogram_container as hist | |||
class TestHistogram:
    """Unit tests for the histogram container."""

    def test_get_buckets(self):
        """Buckets stay available after a visual range has been set."""
        # A single fake bucket is enough: only the number of buckets is checked.
        fake_bucket = mock.MagicMock(left=0, width=1, count=1)
        fake_message = mock.MagicMock(buckets=[fake_bucket])

        container = hist.HistogramContainer(fake_message)
        container.set_visual_range(max_val=1, min_val=0, bins=1)

        assert len(container.buckets()) == 1
| @@ -0,0 +1,37 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Test reservoir.""" | |||
| import unittest.mock as mock | |||
| import mindinsight.datavisual.data_transform.reservoir as reservoir | |||
class TestHistogramReservoir:
    """Unit tests for the histogram reservoir."""

    def test_samples(self):
        """All added samples are returned by samples()."""
        histogram_reservoir = reservoir.ReservoirFactory().create_reservoir(
            reservoir.PluginNameEnum.HISTOGRAM.value, size=10)

        # Two fake samples sharing the same range but holding different counts.
        for number_count in (1, 2):
            fake_sample = mock.MagicMock()
            fake_sample.value.count = number_count
            fake_sample.value.max = 102
            fake_sample.value.min = 101
            histogram_reservoir.add_sample(fake_sample)

        assert len(histogram_reservoir.samples()) == 2