GitOrigin-RevId: ed6af9b98d
tags/v1.3.0
| @@ -153,11 +153,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||||
| **kwargs | **kwargs | ||||
| ): | ): | ||||
| super().__init__(mode, eps, dtype, narrow_range, **kwargs) | super().__init__(mode, eps, dtype, narrow_range, **kwargs) | ||||
| self.momentum = Tensor(momentum) | |||||
| self.momentum = Tensor(momentum, dtype="float32") | |||||
| self.runtime_momentum = Tensor(0.0) | self.runtime_momentum = Tensor(0.0) | ||||
| def set_momentum(self, momentum): | def set_momentum(self, momentum): | ||||
| self.momentum._reset(momentum) | |||||
| self.momentum = Tenosr(momentum, dtype="float32") | |||||
| def forward(self, x_orig): | def forward(self, x_orig): | ||||
| if self.enabled: | if self.enabled: | ||||
| @@ -439,9 +439,9 @@ class HistogramObserver(MinMaxObserver): | |||||
| self.bins, | self.bins, | ||||
| ) | ) | ||||
| self.histogram._reset(new_histogram) | |||||
| self.min_val._reset(new_min) | |||||
| self.max_val._reset(new_max) | |||||
| self.histogram = Tensor(new_histogram, dtype="float32") | |||||
| self.min_val = Tensor(new_min, dtype="float32") | |||||
| self.max_val = Tensor(new_max, dtype="float32") | |||||
| def forward(self, x_orig): | def forward(self, x_orig): | ||||
| self.sideeffect_forward(x_orig) | self.sideeffect_forward(x_orig) | ||||
| @@ -8,6 +8,7 @@ import megengine.distributed as dist | |||||
| from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
| from megengine.quantization.observer import ( | from megengine.quantization.observer import ( | ||||
| ExponentialMovingAverageObserver, | ExponentialMovingAverageObserver, | ||||
| HistogramObserver, | |||||
| MinMaxObserver, | MinMaxObserver, | ||||
| Observer, | Observer, | ||||
| PassiveObserver, | PassiveObserver, | ||||
| @@ -44,6 +45,16 @@ def test_exponential_moving_average_observer(): | |||||
| np.testing.assert_allclose(m.max_val.numpy(), expected_max) | np.testing.assert_allclose(m.max_val.numpy(), expected_max) | ||||
| def test_histogram_observer(): | |||||
| x = np.random.rand(3, 3, 3, 3).astype("float32") | |||||
| np_min, np_max = x.min(), x.max() | |||||
| x = mge.tensor(x) | |||||
| m = HistogramObserver() | |||||
| m(x) | |||||
| np.testing.assert_allclose(m.min_val.numpy(), np_min) | |||||
| np.testing.assert_allclose(m.max_val.numpy(), np_max) | |||||
| def test_passive_observer(): | def test_passive_observer(): | ||||
| q_dict = {"scale": mge.tensor(1.0)} | q_dict = {"scale": mge.tensor(1.0)} | ||||
| m = PassiveObserver(q_dict, "qint8") | m = PassiveObserver(q_dict, "qint8") | ||||