GitOrigin-RevId: 5727f63560
tags/v1.0.0-rc1
| @@ -92,6 +92,25 @@ class QATModule(Module): | |||||
| else: | else: | ||||
| return self.act_observer.get_dtype() | return self.act_observer.get_dtype() | ||||
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
    r"""
    Pick the source of quantization parameters for one tensor.

    The fake-quant object is used whenever it exposes a ``get_qparams``
    attribute; otherwise the observer is asked, provided one was given.

    :param fake_quant: object that may provide ``get_qparams()``.
    :param observer: fallback provider of ``get_qparams()``, or ``None``.
    :return: whatever the chosen ``get_qparams()`` returns, or ``None`` when
        neither source is usable.
    """
    # The fake-quant side takes priority over the observer.
    if hasattr(fake_quant, "get_qparams"):
        return fake_quant.get_qparams()
    if observer is None:
        return None
    return observer.get_qparams()
def get_weight_qparams(self):
    r"""
    Return the quantization parameters of the weight, resolved from
    ``self.weight_fake_quant`` and ``self.weight_observer``.
    """
    fake_quant, observer = self.weight_fake_quant, self.weight_observer
    return self._get_qparams(fake_quant, observer)
def get_activation_qparams(self):
    r"""
    Return the quantization parameters of the activation, resolved from
    ``self.act_fake_quant`` and ``self.act_observer``.
    """
    fake_quant, observer = self.act_fake_quant, self.act_observer
    return self._get_qparams(fake_quant, observer)
| @classmethod | @classmethod | ||||
| @abstractmethod | @abstractmethod | ||||
| def from_float_module(cls, float_module: Module): | def from_float_module(cls, float_module: Module): | ||||
| @@ -8,7 +8,7 @@ | |||||
| from .fake_quant import FakeQuantize | from .fake_quant import FakeQuantize | ||||
| from .internal_fake_quant import * | from .internal_fake_quant import * | ||||
| from .observer import HistogramObserver, Observer, ObserverMode | |||||
| from .observer import HistogramObserver, Observer | |||||
| from .qconfig import ( | from .qconfig import ( | ||||
| QConfig, | QConfig, | ||||
| calibration_qconfig, | calibration_qconfig, | ||||
| @@ -16,3 +16,4 @@ from .qconfig import ( | |||||
| min_max_fakequant_qconfig, | min_max_fakequant_qconfig, | ||||
| tqt_quant_qconfig, | tqt_quant_qconfig, | ||||
| ) | ) | ||||
| from .utils import QuantMode | |||||
| @@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||||
| from ..core import Buffer, Function, Parameter | from ..core import Buffer, Function, Parameter | ||||
| from ..jit import sideeffect | from ..jit import sideeffect | ||||
| from ..module import Module | from ..module import Module | ||||
| from .observer import ObserverMode, Round | |||||
| from .observer import Round | |||||
| from .utils import QuantMode, get_qparam_dict | |||||
| class _FakeQuantize(Module): | class _FakeQuantize(Module): | ||||
| @@ -121,8 +122,18 @@ class TQT(_FakeQuantize): | |||||
| F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | ||||
| return inp | return inp | ||||
def get_qparams(self):
    r"""
    Return the quantization parameters of this TQT module.

    :return: a dict based on the ``QuantMode.TQT`` template whose ``"scale"``
        entry is ``2 ** self.scale``.
    """
    # Work on a private copy so that writing "scale" below can never alter a
    # dict that is shared with other callers of get_qparam_dict.
    qdict = dict(get_qparam_dict(QuantMode.TQT))
    qdict["scale"] = 2 ** self.scale
    return qdict
def get_dtype(self):
    r"""
    Build the quantized dtype described by this module's qparams.

    The scale and zero point are read from :meth:`get_qparams`; each one is
    passed as ``None`` when the entry is missing or unset, otherwise as the
    first element of its numpy value.

    :return: the result of ``get_quantized_dtype(self.dtype, scale, zero_point)``.
    """
    q_dict = self.get_qparams()
    scale = q_dict.get("scale")
    zero_point = q_dict.get("zero_point")
    return get_quantized_dtype(
        self.dtype,
        # Guard on None so an unset template entry is not dereferenced.
        None if scale is None else scale.numpy()[0],
        None if zero_point is None else zero_point.numpy()[0],
    )
| class FakeQuantize(_FakeQuantize): | class FakeQuantize(_FakeQuantize): | ||||
| @@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize): | |||||
| """ | """ | ||||
| def fake_quant_forward(self, inp, q_dict): | def fake_quant_forward(self, inp, q_dict): | ||||
| if q_dict["mode"] == ObserverMode.SYMMERTIC: | |||||
| if q_dict["mode"] == QuantMode.SYMMERTIC: | |||||
| scale = q_dict["scale"] | scale = q_dict["scale"] | ||||
| # Quant | # Quant | ||||
| oup = Round()(inp / scale) | oup = Round()(inp / scale) | ||||
| @@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||||
| from ..core import Buffer, Function, tensor | from ..core import Buffer, Function, tensor | ||||
| from ..jit import sideeffect | from ..jit import sideeffect | ||||
| from ..module import Module | from ..module import Module | ||||
| from .utils import QuantMode, get_qparam_dict | |||||
| class Round(Function): | class Round(Function): | ||||
| @@ -81,29 +82,10 @@ class Observer(Module): | |||||
| pass | pass | ||||
| class ObserverMode(Enum): | |||||
| SYMMERTIC = 1 | |||||
| ASYMMERTIC = 2 | |||||
| def create_observer_dict(mode): | |||||
| if mode == ObserverMode.SYMMERTIC: | |||||
| return { | |||||
| "mode": ObserverMode.SYMMERTIC, | |||||
| "scale": None, | |||||
| } | |||||
| else: | |||||
| return { | |||||
| "mode": ObserverMode.ASYMMERTIC, | |||||
| "scale": None, | |||||
| "zero_point": None, | |||||
| } | |||||
| class MinMaxObserver(Observer): | class MinMaxObserver(Observer): | ||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| mode=ObserverMode.SYMMERTIC, | |||||
| mode=QuantMode.SYMMERTIC, | |||||
| eps=0.00001, | eps=0.00001, | ||||
| dtype="qint8", | dtype="qint8", | ||||
| narrow_range: bool = False, | narrow_range: bool = False, | ||||
| @@ -117,10 +99,10 @@ class MinMaxObserver(Observer): | |||||
| def _calculate_qparams(self, inp_min_val, inp_max_val): | def _calculate_qparams(self, inp_min_val, inp_max_val): | ||||
| min_val = F.minimum(0.0, inp_min_val) | min_val = F.minimum(0.0, inp_min_val) | ||||
| max_val = F.maximum(0.0, inp_max_val) | max_val = F.maximum(0.0, inp_max_val) | ||||
| q_dict = create_observer_dict(self.mode) | |||||
| q_dict = get_qparam_dict(self.mode) | |||||
| q_dict["min_val"] = inp_min_val | q_dict["min_val"] = inp_min_val | ||||
| q_dict["max_val"] = inp_max_val | q_dict["max_val"] = inp_max_val | ||||
| if self.mode == ObserverMode.SYMMERTIC: | |||||
| if self.mode == QuantMode.SYMMERTIC: | |||||
| symmetric_max_vals = F.maximum(-min_val, max_val) | symmetric_max_vals = F.maximum(-min_val, max_val) | ||||
| # use maximum to avoid the scale being too small at the beginning | # use maximum to avoid the scale being too small at the beginning | ||||
| q_dict["scale"] = F.maximum( | q_dict["scale"] = F.maximum( | ||||
| @@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| momentum=0.9, | momentum=0.9, | ||||
| mode=ObserverMode.SYMMERTIC, | |||||
| mode=QuantMode.SYMMERTIC, | |||||
| eps=0.00001, | eps=0.00001, | ||||
| dtype="qint8", | dtype="qint8", | ||||
| narrow_range: bool = False, | narrow_range: bool = False, | ||||
| @@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver): | |||||
| self, | self, | ||||
| bins=2048, | bins=2048, | ||||
| upsample_rate=128, | upsample_rate=128, | ||||
| mode=ObserverMode.SYMMERTIC, | |||||
| mode=QuantMode.SYMMERTIC, | |||||
| eps=0.00001, | eps=0.00001, | ||||
| dtype="qint8", | dtype="qint8", | ||||
| narrow_range: bool = False, | narrow_range: bool = False, | ||||
| @@ -6,6 +6,7 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from enum import Enum | |||||
| from functools import partial, update_wrapper, wraps | from functools import partial, update_wrapper, wraps | ||||
| @@ -21,3 +22,24 @@ def register_method_to_class(cls): | |||||
| return func | return func | ||||
| return decorator | return decorator | ||||
class QuantMode(Enum):
    """Quantization modes understood by the observers and fake-quant modules.

    The member spellings are kept as-is because other modules refer to them
    by name (e.g. ``QuantMode.SYMMERTIC``).
    """

    SYMMERTIC = 1
    ASYMMERTIC = 2
    TQT = 3


# Per-mode templates listing the keys a qparam dict starts with. These are
# templates only: callers must obtain their own dict via get_qparam_dict.
qparam_dict = {
    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
    QuantMode.ASYMMERTIC: {
        "mode": QuantMode.ASYMMERTIC,
        "scale": None,
        "zero_point": None,
    },
    QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None},
}


def get_qparam_dict(mode):
    """Return a fresh qparam dict for *mode*, or ``None`` if it is unknown.

    Callers fill in entries such as ``"scale"`` on the returned dict, so a
    copy of the template is handed out; returning the template itself would
    let one caller's values leak into every later call.

    :param mode: a :class:`QuantMode` member.
    :return: a new dict pre-populated with the template keys for *mode*, or
        ``None`` when *mode* has no template.
    """
    template = qparam_dict.get(mode)
    if template is None:
        return None
    # Shallow copy is enough: template values are only None / enum members.
    return dict(template)