GitOrigin-RevId: f8511f72ad
tags/v1.3.0
@@ -6,12 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import math
-from typing import Iterable
-import numpy as np
 from .. import functional as F
-from ..autodiff import Function
 from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
 from ..module import Module
 from ..tensor import Parameter, Tensor
@@ -72,20 +68,10 @@ class TQT(_FakeQuantize):
     """

     def __init__(
-        self,
-        q_dict,
-        dtype: str,
-        narrow_range: bool = False,
-        enable: bool = True,
-        **kwargs
+        self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
     ):
         super().__init__(dtype, narrow_range, enable, **kwargs)
-        assert (
-            q_dict["mode"] == QuantMode.SYMMERTIC
-        ), "only symmetric quantization is supported by TQT"
-        if "scale" not in q_dict or q_dict["scale"] is None:
-            raise AssertionError("Can not get an initialized scale")
-        self.scale = Tensor(F.log(q_dict["scale"]) / math.log(2))
+        self.scale = Parameter(0.0, dtype="float32")

     def fake_quant_forward(self, inp, q_dict=None):
         # when enabled, TQT runs the fake-quant forward pass and finetunes the scale
@@ -93,14 +79,22 @@ class TQT(_FakeQuantize):
     def get_qparams(self):
         q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
-        q_dict["scale"] = 2 ** self.scale
+        q_dict["scale"] = 2 ** self.scale.detach()
         return q_dict

+    def set_qparams(self, q_dict):
+        assert (
+            q_dict["mode"] == QuantMode.SYMMERTIC
+        ), "only symmetric quantization is supported by TQT"
+        if "scale" not in q_dict or q_dict["scale"] is None:
+            raise AssertionError("Can not get an initialized scale")
+        self.scale._reset(F.log(q_dict["scale"]) / math.log(2))
+
     def get_dtype(self):
         q_dict = self.get_qparams()
-        scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
+        scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
         zero_point = (
-            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
+            None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
         )
         return get_quantized_dtype(self.dtype, scale, zero_point)
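With this change, TQT no longer receives a q_dict at construction: the scale starts as a zero-initialized Parameter and is loaded afterwards through set_qparams, which stores log2 of the given scale so that get_qparams can recover it as 2 ** scale. A minimal sketch of the new flow (import paths and the literal values are assumptions based on MegEngine's layout around v1.3, not part of this diff):

    import megengine as mge
    from megengine.quantization.fake_quant import TQT
    from megengine.quantization.utils import QuantMode, get_qparam_dict

    # construct first, load qparams later (the old API took q_dict in __init__)
    tqt = TQT(dtype="qint8", narrow_range=False)
    q_dict = get_qparam_dict(QuantMode.SYMMERTIC)  # "SYMMERTIC" is the codebase's own spelling
    q_dict["scale"] = mge.tensor(0.25)             # example value
    tqt.set_qparams(q_dict)                        # stores log2(0.25) == -2.0 internally

    # get_qparams returns the scale as 2 ** self.scale.detach()
    assert abs(tqt.get_qparams()["scale"].numpy() - 0.25) < 1e-4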
@@ -17,7 +17,7 @@ from ..distributed import WORLD, get_rank, is_distributed
 from ..functional.distributed import all_reduce_max, all_reduce_min
 from ..module import Module
 from ..tensor import Tensor
-from .utils import QuantMode, Round, get_qparam_dict
+from .utils import QuantMode, get_qparam_dict


 class Observer(Module):
@@ -110,7 +110,7 @@ class MinMaxObserver(Observer):
             (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
         )
         # calculate zero_point
-        q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
+        q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"])
         return q_dict
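The zero-point rule itself is unchanged, only the rounding op is replaced: scale spreads the observed range over the integer range, and zero_point shifts qmin onto min_val. A quick plain-numpy check with made-up statistics (values are illustrative only):

    import numpy as np

    min_val, max_val = -1.0, 3.0    # example observed range
    qmin, qmax = -128, 127          # example integer range
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - np.round(min_val / scale)
    print(scale, zero_point)        # ~0.01569, -64.0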
@@ -453,12 +453,10 @@ class PassiveObserver(Observer):
     This class allows :attr:`scale` to be set directly.
     """

-    def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs):
+    def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
         super().__init__(dtype, narrow_range, **kwargs)
-        self.q_dict = deepcopy(q_dict)
-        if "scale" not in q_dict or q_dict["scale"] is None:
-            raise AssertionError("Can not get an initialized scale")
-        self.orig_scale = q_dict["scale"].numpy()
+        self.q_dict = None
+        self.orig_scale = None

     @property
     def scale(self):
@@ -472,6 +470,12 @@ class PassiveObserver(Observer):
     def get_qparams(self):
         return self.q_dict

+    def set_qparams(self, q_dict):
+        self.q_dict = deepcopy(q_dict)
+        if "scale" not in q_dict or q_dict["scale"] is None:
+            raise AssertionError("Can not get an initialized scale")
+        self.orig_scale = q_dict["scale"].numpy()
+
     def forward(self, x):
         r"""
         Just return the input, because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
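PassiveObserver now follows the same two-step pattern: construct with just a dtype, then hand over qparams. A short sketch (import path assumed, same caveats as above); the updated test_passive_observer below exercises exactly this path:

    import megengine as mge
    from megengine.quantization.observer import PassiveObserver

    obs = PassiveObserver("qint8")               # no q_dict in __init__ any more
    obs.set_qparams({"scale": mge.tensor(1.0)})  # loads scale, records orig_scale
    obs.scale = 2.0                              # property setter updates q_dict["scale"]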
@@ -152,7 +152,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
         module = deepcopy(module)

     def safe_call(func, q_dict):
-        return func(q_dict=q_dict) if func is not None else None
+        inst = func() if func is not None else None
+        if inst is not None and getattr(inst, "set_qparams", None) is not None:
+            inst.set_qparams(q_dict)
+        return inst

     for m in list(module._flatten(predicate=is_qat)):
         if m.with_weight:
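The reworked safe_call decouples construction from qparam loading: a factory from the QConfig is now called with no arguments, and set_qparams runs only when the instance provides it, so classes without qparams pass through untouched. A standalone sketch of the same pattern (safe_call is a nested helper inside reset_qconfig and not importable; this copy is for illustration only):

    from functools import partial

    import megengine as mge
    from megengine.quantization.observer import PassiveObserver

    def safe_call(func, q_dict):
        # construct with no arguments, then load qparams if supported
        inst = func() if func is not None else None
        if inst is not None and getattr(inst, "set_qparams", None) is not None:
            inst.set_qparams(q_dict)
        return inst

    # a QConfig typically stores dtype-bound factories like this
    obs = safe_call(partial(PassiveObserver, dtype="qint8"), {"scale": mge.tensor(2.0)})
    assert obs.scale == 2.0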
@@ -41,8 +41,8 @@ def test_exponential_moving_average_observer():
     m = ExponentialMovingAverageObserver(momentum=t)
     m(mge.tensor(x1, dtype=np.float32))
     m(mge.tensor(x2, dtype=np.float32))
-    np.testing.assert_allclose(m.min_val.numpy(), expected_min)
-    np.testing.assert_allclose(m.max_val.numpy(), expected_max)
+    np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-5)
+    np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-5)
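The added atol absorbs float32 rounding in the moving-average accumulation. The expectation the test compares against is a one-step EMA of the two batches' statistics, along these lines (momentum and shapes are assumptions for illustration):

    import numpy as np

    t = 0.9  # example momentum
    x1 = np.random.rand(3, 3, 3, 3).astype("float32")
    x2 = np.random.rand(3, 3, 3, 3).astype("float32")
    # second batch is blended into the first batch's min/max with weight (1 - t)
    expected_min = x1.min() * t + x2.min() * (1 - t)
    expected_max = x1.max() * t + x2.max() * (1 - t)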
@@ -57,7 +57,8 @@ def test_histogram_observer():
 def test_passive_observer():
     q_dict = {"scale": mge.tensor(1.0)}
-    m = PassiveObserver(q_dict, "qint8")
+    m = PassiveObserver("qint8")
+    m.set_qparams(q_dict)
     assert m.orig_scale == 1.0
     assert m.scale == 1.0
     m.scale = 2.0