GitOrigin-RevId: 92389341be
tags/v0.6.0
| @@ -37,15 +37,14 @@ class QATModule(Module): | |||||
| Set quantization related configs with ``qconfig``, including | Set quantization related configs with ``qconfig``, including | ||||
| observer and fake_quant for weight and activation. | observer and fake_quant for weight and activation. | ||||
| """ | """ | ||||
| self.weight_observer = qconfig.weight_observer() | |||||
| self.act_observer = qconfig.act_observer() | |||||
| if qconfig.fake_quant is None: | |||||
| self.weight_fake_quant = None | |||||
| self.act_fake_quant = None | |||||
| else: | |||||
| self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) | |||||
| self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) | |||||
| def safe_call(func): | |||||
| return func() if func is not None else None | |||||
| self.weight_observer = safe_call(qconfig.weight_observer) | |||||
| self.act_observer = safe_call(qconfig.act_observer) | |||||
| self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | |||||
| self.act_fake_quant = safe_call(qconfig.act_fake_quant) | |||||
| def _apply_fakequant_with_observer( | def _apply_fakequant_with_observer( | ||||
| self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | ||||
| @@ -19,7 +19,7 @@ from .observer import ObserverMode, Round | |||||
| class _FakeQuantize(Module): | class _FakeQuantize(Module): | ||||
| def __init__(self, dtype: str, enable: bool = True): | |||||
| def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||||
| super().__init__() | super().__init__() | ||||
| if not dtype in _metadata_dict.keys(): | if not dtype in _metadata_dict.keys(): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -28,7 +28,10 @@ class _FakeQuantize(Module): | |||||
| ) | ) | ||||
| ) | ) | ||||
| self.dtype = dtype | self.dtype = dtype | ||||
| self.qmin = _metadata_dict[dtype].qmin | |||||
| self.narrow_range = narrow_range | |||||
| self.qmin = ( | |||||
| -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||||
| ) | |||||
| self.qmax = _metadata_dict[dtype].qmax | self.qmax = _metadata_dict[dtype].qmax | ||||
| self.enabled = enable | self.enabled = enable | ||||
| @@ -90,12 +93,12 @@ class TQT_Function(Function): | |||||
| class TQT(_FakeQuantize): | class TQT(_FakeQuantize): | ||||
| """ | """ | ||||
| TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | |||||
| TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | |||||
| for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks | for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks | ||||
| """ | """ | ||||
| def __init__(self, dtype: str, enable: bool = True): | |||||
| super().__init__(dtype, enable) | |||||
| def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||||
| super().__init__(dtype, narrow_range, enable) | |||||
| self.scale = Parameter(0.0, dtype=np.float32) | self.scale = Parameter(0.0, dtype=np.float32) | ||||
| def fake_quant_forward(self, inp, q_dict): | def fake_quant_forward(self, inp, q_dict): | ||||
| @@ -116,6 +119,11 @@ class TQT(_FakeQuantize): | |||||
| class FakeQuantize(_FakeQuantize): | class FakeQuantize(_FakeQuantize): | ||||
| r""" | r""" | ||||
| A module to do quant and dequant according to observer's scale and zero_point. | A module to do quant and dequant according to observer's scale and zero_point. | ||||
| :param dtype: A string indicating the target quantization type of input. | |||||
| :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
| instead of 1 greater. Usually True for weight and False for activation. | |||||
| :param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. | |||||
| """ | """ | ||||
| def fake_quant_forward(self, inp, q_dict): | def fake_quant_forward(self, inp, q_dict): | ||||
| @@ -31,9 +31,11 @@ class Observer(Module): | |||||
| A base class for Observer Module. | A base class for Observer Module. | ||||
| :param dtype: a string indicating to collect scale and zero_point of which dtype | :param dtype: a string indicating to collect scale and zero_point of which dtype | ||||
| :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
| instead of 1 greater. Usually True for weight and False for activation. | |||||
| """ | """ | ||||
| def __init__(self, dtype="qint8"): | |||||
| def __init__(self, dtype: str, narrow_range: bool = False): | |||||
| super().__init__() | super().__init__() | ||||
| if dtype not in _metadata_dict.keys(): | if dtype not in _metadata_dict.keys(): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -42,7 +44,10 @@ class Observer(Module): | |||||
| ) | ) | ||||
| ) | ) | ||||
| self.dtype = dtype | self.dtype = dtype | ||||
| self.qmin = _metadata_dict[dtype].qmin | |||||
| self.narrow_range = narrow_range | |||||
| self.qmin = ( | |||||
| -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||||
| ) | |||||
| self.qmax = _metadata_dict[dtype].qmax | self.qmax = _metadata_dict[dtype].qmax | ||||
| self.enabled = True | self.enabled = True | ||||
| @@ -96,8 +101,14 @@ def create_observer_dict(mode): | |||||
| class MinMaxObserver(Observer): | class MinMaxObserver(Observer): | ||||
| def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): | |||||
| super().__init__(dtype) | |||||
| def __init__( | |||||
| self, | |||||
| mode=ObserverMode.SYMMERTIC, | |||||
| eps=0.00001, | |||||
| dtype="qint8", | |||||
| narrow_range: bool = False, | |||||
| ): | |||||
| super().__init__(dtype, narrow_range) | |||||
| self.mode = mode | self.mode = mode | ||||
| self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | ||||
| self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | ||||
| @@ -153,9 +164,14 @@ class MinMaxObserver(Observer): | |||||
| class ExponentialMovingAverageObserver(MinMaxObserver): | class ExponentialMovingAverageObserver(MinMaxObserver): | ||||
| def __init__( | def __init__( | ||||
| self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8" | |||||
| self, | |||||
| momentum=0.9, | |||||
| mode=ObserverMode.SYMMERTIC, | |||||
| eps=0.00001, | |||||
| dtype="qint8", | |||||
| narrow_range: bool = False, | |||||
| ): | ): | ||||
| super().__init__(mode, eps, dtype) | |||||
| super().__init__(mode, eps, dtype, narrow_range) | |||||
| self.momentum = Buffer(momentum) | self.momentum = Buffer(momentum) | ||||
| self.runtime_momentum = Buffer(0.0) | self.runtime_momentum = Buffer(0.0) | ||||
| @@ -188,11 +204,12 @@ class HistogramObserver(MinMaxObserver): | |||||
| self, | self, | ||||
| bins=2048, | bins=2048, | ||||
| upsample_rate=128, | upsample_rate=128, | ||||
| dtype="qint8", | |||||
| mode=ObserverMode.SYMMERTIC, | mode=ObserverMode.SYMMERTIC, | ||||
| eps=0.00001, | eps=0.00001, | ||||
| dtype="qint8", | |||||
| narrow_range: bool = False, | |||||
| ): | ): | ||||
| super().__init__(mode, eps, dtype) | |||||
| super().__init__(mode, eps, dtype, narrow_range) | |||||
| self.bins = bins | self.bins = bins | ||||
| self.upsample_rate = upsample_rate | self.upsample_rate = upsample_rate | ||||
| self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | ||||
| @@ -5,6 +5,8 @@ | |||||
| # Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from functools import partial | |||||
| from ..module import Module | from ..module import Module | ||||
| from .fake_quant import TQT, FakeQuantize | from .fake_quant import TQT, FakeQuantize | ||||
| from .observer import ( | from .observer import ( | ||||
| @@ -22,9 +24,9 @@ class QConfig: | |||||
| :param weight_observer: interface to instantiate an :class:`~.Observer` indicating | :param weight_observer: interface to instantiate an :class:`~.Observer` indicating | ||||
| how to collect scales and zero_point of wegiht. | how to collect scales and zero_point of wegiht. | ||||
| :param act_observer: similar to ``weight_observer`` but toward activation. | :param act_observer: similar to ``weight_observer`` but toward activation. | ||||
| :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||||
| how to do fake_quant calculation. can be invoked multi times to get different | |||||
| instance for each target tensor, for better control on enable and disable. | |||||
| :param weight_fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||||
| how to do fake_quant calculation. | |||||
| :param act_observer: similar to ``weight_fake_quant`` but toward activation. | |||||
| Examples: | Examples: | ||||
| @@ -32,14 +34,24 @@ class QConfig: | |||||
| # Default EMA QConfig for QAT. | # Default EMA QConfig for QAT. | ||||
| ema_fakequant_qconfig = QConfig( | ema_fakequant_qconfig = QConfig( | ||||
| weight_observer=MinMaxObserver, | |||||
| act_observer=ExponentialMovingAverageObserver, | |||||
| fake_quant=FakeQuantize, | |||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
| act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False), | |||||
| weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
| ) | ) | ||||
| Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial`` | |||||
| to add initialization parameters of the ``class``, so that don't need to provide parameters in | |||||
| :meth:`~.QATModule.set_qconfig`. | |||||
| Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related | |||||
| parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if | |||||
| four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow. | |||||
| Weights are commonly calculated in this way, so needed to narrow the range. | |||||
| """ | """ | ||||
| def __init__( | def __init__( | ||||
| self, act_observer, weight_observer, fake_quant, | |||||
| self, weight_observer, act_observer, weight_fake_quant, act_fake_quant | |||||
| ): | ): | ||||
| if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -47,30 +59,42 @@ class QConfig: | |||||
| " class generator using `partial(Observer, ...)` instead. Use" | " class generator using `partial(Observer, ...)` instead. Use" | ||||
| " partial(MyObserver, x=1) to override arguments to constructor if needed" | " partial(MyObserver, x=1) to override arguments to constructor if needed" | ||||
| ) | ) | ||||
| self.act_observer = act_observer | |||||
| self.weight_observer = weight_observer | self.weight_observer = weight_observer | ||||
| self.fake_quant = fake_quant | |||||
| self.act_observer = act_observer | |||||
| self.weight_fake_quant = weight_fake_quant | |||||
| self.act_fake_quant = act_fake_quant | |||||
| tqt_quant_qconfig = QConfig( | tqt_quant_qconfig = QConfig( | ||||
| weight_observer=ExponentialMovingAverageObserver, | |||||
| act_observer=ExponentialMovingAverageObserver, | |||||
| fake_quant=TQT, | |||||
| weight_observer=partial( | |||||
| ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True | |||||
| ), | |||||
| act_observer=partial( | |||||
| ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
| ), | |||||
| weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||||
| ) | ) | ||||
| # Default QAT QConfigs | |||||
| min_max_fakequant_qconfig = QConfig( | min_max_fakequant_qconfig = QConfig( | ||||
| weight_observer=MinMaxObserver, | |||||
| act_observer=MinMaxObserver, | |||||
| fake_quant=FakeQuantize, | |||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
| act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), | |||||
| weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
| ) | ) | ||||
| ema_fakequant_qconfig = QConfig( | ema_fakequant_qconfig = QConfig( | ||||
| weight_observer=MinMaxObserver, | |||||
| act_observer=ExponentialMovingAverageObserver, | |||||
| fake_quant=FakeQuantize, | |||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
| act_observer=partial( | |||||
| ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
| ), | |||||
| weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
| ) | ) | ||||
| calibration_qconfig = QConfig( | calibration_qconfig = QConfig( | ||||
| weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None, | |||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
| act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False), | |||||
| weight_fake_quant=None, | |||||
| act_fake_quant=None, | |||||
| ) | ) | ||||