GitOrigin-RevId: 5727f63560
tags/v1.0.0-rc1
| @@ -92,6 +92,25 @@ class QATModule(Module): | |||
| else: | |||
| return self.act_observer.get_dtype() | |||
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
    r"""
    Pick the object that can report quantization parameters and query it.

    ``fake_quant`` is preferred whenever it exposes a ``get_qparams``
    attribute; otherwise ``observer`` is used if it is not ``None``.

    :param fake_quant: fake-quantize object, consulted first.
    :param observer: fallback observer; may be ``None``.
    :return: whatever the chosen object's ``get_qparams()`` returns, or
        ``None`` when neither object can provide parameters.
    """
    provider = None
    if hasattr(fake_quant, "get_qparams"):
        provider = fake_quant
    elif observer is not None:
        provider = observer
    return None if provider is None else provider.get_qparams()
def get_weight_qparams(self):
    r"""
    Return the quantization parameters for the weight.

    Delegates to ``_get_qparams`` with this module's ``weight_fake_quant``
    and ``weight_observer``.
    """
    fake_quant, observer = self.weight_fake_quant, self.weight_observer
    return self._get_qparams(fake_quant, observer)
def get_activation_qparams(self):
    r"""
    Return the quantization parameters for the activation.

    Delegates to ``_get_qparams`` with this module's ``act_fake_quant``
    and ``act_observer``.
    """
    fake_quant, observer = self.act_fake_quant, self.act_observer
    return self._get_qparams(fake_quant, observer)
| @classmethod | |||
| @abstractmethod | |||
| def from_float_module(cls, float_module: Module): | |||
| @@ -8,7 +8,7 @@ | |||
| from .fake_quant import FakeQuantize | |||
| from .internal_fake_quant import * | |||
| from .observer import HistogramObserver, Observer, ObserverMode | |||
| from .observer import HistogramObserver, Observer | |||
| from .qconfig import ( | |||
| QConfig, | |||
| calibration_qconfig, | |||
| @@ -16,3 +16,4 @@ from .qconfig import ( | |||
| min_max_fakequant_qconfig, | |||
| tqt_quant_qconfig, | |||
| ) | |||
| from .utils import QuantMode | |||
| @@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||
| from ..core import Buffer, Function, Parameter | |||
| from ..jit import sideeffect | |||
| from ..module import Module | |||
| from .observer import ObserverMode, Round | |||
| from .observer import Round | |||
| from .utils import QuantMode, get_qparam_dict | |||
| class _FakeQuantize(Module): | |||
| @@ -121,8 +122,18 @@ class TQT(_FakeQuantize): | |||
| F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||
| return inp | |||
def get_qparams(self):
    r"""
    Build the quantization-parameter dictionary for this TQT module.

    :return: a new dict based on the ``QuantMode.TQT`` template whose
        ``"scale"`` entry is ``2 ** self.scale``.
    """
    # get_qparam_dict may hand back the shared module-level template; copy
    # it so writing "scale" here cannot leak into other modules or callers.
    qdict = dict(get_qparam_dict(QuantMode.TQT))
    qdict["scale"] = 2 ** self.scale
    return qdict
def get_dtype(self):
    r"""
    Derive the quantized dtype from this module's quantization parameters.

    The parameters come from ``self.get_qparams()``; an entry that is absent
    from that dict is passed to ``get_quantized_dtype`` as ``None``.

    :return: the result of ``get_quantized_dtype(self.dtype, scale, zero_point)``.
    """
    # The stale early return that computed the dtype directly from
    # ``self.scale`` is dropped: it made the qparams-based path unreachable.
    q_dict = self.get_qparams()
    # ``.numpy()[0]`` takes the first element of each stored value
    # (presumably a one-element tensor -- TODO confirm).
    scale = q_dict["scale"].numpy()[0] if "scale" in q_dict else None
    zero_point = (
        q_dict["zero_point"].numpy()[0] if "zero_point" in q_dict else None
    )
    return get_quantized_dtype(self.dtype, scale, zero_point)
| class FakeQuantize(_FakeQuantize): | |||
| @@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize): | |||
| """ | |||
| def fake_quant_forward(self, inp, q_dict): | |||
| if q_dict["mode"] == ObserverMode.SYMMERTIC: | |||
| if q_dict["mode"] == QuantMode.SYMMERTIC: | |||
| scale = q_dict["scale"] | |||
| # Quant | |||
| oup = Round()(inp / scale) | |||
| @@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||
| from ..core import Buffer, Function, tensor | |||
| from ..jit import sideeffect | |||
| from ..module import Module | |||
| from .utils import QuantMode, get_qparam_dict | |||
| class Round(Function): | |||
| @@ -81,29 +82,10 @@ class Observer(Module): | |||
| pass | |||
class ObserverMode(Enum):
    r"""
    Mode tag stored under the ``"mode"`` key of an observer's parameter dict.
    """

    # Member spellings are part of the public interface; keep them as-is.
    SYMMERTIC = 1
    ASYMMERTIC = 2
def create_observer_dict(mode):
    r"""
    Create an empty parameter dictionary for the given observer mode.

    :param mode: ``ObserverMode.SYMMERTIC`` yields a dict with ``"mode"`` and
        ``"scale"``; any other value yields the asymmetric layout, which
        additionally carries ``"zero_point"``.
    :return: a new dict whose parameter slots are all ``None``.
    """
    is_symmetric = mode == ObserverMode.SYMMERTIC
    params = {
        "mode": ObserverMode.SYMMERTIC if is_symmetric else ObserverMode.ASYMMERTIC,
        "scale": None,
    }
    if not is_symmetric:
        # Only the asymmetric layout has a zero-point slot.
        params["zero_point"] = None
    return params
| class MinMaxObserver(Observer): | |||
| def __init__( | |||
| self, | |||
| mode=ObserverMode.SYMMERTIC, | |||
| mode=QuantMode.SYMMERTIC, | |||
| eps=0.00001, | |||
| dtype="qint8", | |||
| narrow_range: bool = False, | |||
| @@ -117,10 +99,10 @@ class MinMaxObserver(Observer): | |||
| def _calculate_qparams(self, inp_min_val, inp_max_val): | |||
| min_val = F.minimum(0.0, inp_min_val) | |||
| max_val = F.maximum(0.0, inp_max_val) | |||
| q_dict = create_observer_dict(self.mode) | |||
| q_dict = get_qparam_dict(self.mode) | |||
| q_dict["min_val"] = inp_min_val | |||
| q_dict["max_val"] = inp_max_val | |||
| if self.mode == ObserverMode.SYMMERTIC: | |||
| if self.mode == QuantMode.SYMMERTIC: | |||
| symmetric_max_vals = F.maximum(-min_val, max_val) | |||
| # use maximum to avoid the scale being too small at the beginning | |||
| q_dict["scale"] = F.maximum( | |||
| @@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||
| def __init__( | |||
| self, | |||
| momentum=0.9, | |||
| mode=ObserverMode.SYMMERTIC, | |||
| mode=QuantMode.SYMMERTIC, | |||
| eps=0.00001, | |||
| dtype="qint8", | |||
| narrow_range: bool = False, | |||
| @@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver): | |||
| self, | |||
| bins=2048, | |||
| upsample_rate=128, | |||
| mode=ObserverMode.SYMMERTIC, | |||
| mode=QuantMode.SYMMERTIC, | |||
| eps=0.00001, | |||
| dtype="qint8", | |||
| narrow_range: bool = False, | |||
| @@ -6,6 +6,7 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from enum import Enum | |||
| from functools import partial, update_wrapper, wraps | |||
| @@ -21,3 +22,24 @@ def register_method_to_class(cls): | |||
| return func | |||
| return decorator | |||
class QuantMode(Enum):
    r"""
    Identifier of the quantization mode a parameter dictionary describes.

    The member spellings (``SYMMERTIC`` / ``ASYMMERTIC``) are kept exactly as
    declared because callers reference them by name.
    """

    SYMMERTIC = 1
    ASYMMERTIC = 2
    TQT = 3


# Per-mode templates listing the parameter slots each mode carries.
# Treat these as read-only: callers obtain their own copy through
# get_qparam_dict() and fill the slots there.
qparam_dict = {
    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
    QuantMode.ASYMMERTIC: {
        "mode": QuantMode.ASYMMERTIC,
        "scale": None,
        "zero_point": None,
    },
    QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None},
}


def get_qparam_dict(mode):
    r"""
    Return a fresh parameter dictionary for ``mode``.

    :param mode: a :class:`QuantMode` member.
    :return: a new dict copied from the template registered for ``mode``,
        or ``None`` when no template exists for it.
    """
    template = qparam_dict.get(mode)
    if template is None:
        return None
    # Callers write "scale", "min_val", etc. into the result. Returning the
    # template itself would leak those writes into every later call, so hand
    # out a copy. A shallow copy suffices: the template values are only enum
    # members and None.
    return dict(template)