GitOrigin-RevId: 060d908349
tags/v1.2.0
| @@ -17,9 +17,7 @@ from .module import QuantizedModule | |||||
| class Linear(QuantizedModule): | class Linear(QuantizedModule): | ||||
| r"""Quantized version of :class:`~.qat.linear.Linear`.""" | r"""Quantized version of :class:`~.qat.linear.Linear`.""" | ||||
| def __init__( | |||||
| self, dtype: np.dtype = None, | |||||
| ): | |||||
| def __init__(self, dtype: np.dtype = None): | |||||
| super().__init__() | super().__init__() | ||||
| self.weight = None | self.weight = None | ||||
| self.bias = None | self.bias = None | ||||
| @@ -15,7 +15,8 @@ from .qconfig import ( | |||||
| ema_fakequant_qconfig, | ema_fakequant_qconfig, | ||||
| ema_lowbit_fakequant_qconfig, | ema_lowbit_fakequant_qconfig, | ||||
| min_max_fakequant_qconfig, | min_max_fakequant_qconfig, | ||||
| passive_qconfig, | |||||
| sync_ema_fakequant_qconfig, | sync_ema_fakequant_qconfig, | ||||
| tqt_quant_qconfig, | |||||
| tqt_qconfig, | |||||
| ) | ) | ||||
| from .utils import QuantMode | from .utils import QuantMode | ||||
| @@ -28,7 +28,9 @@ class _FakeQuantize(Module): | |||||
| :param enable: whether do ``normal_forward`` or ``fake_quant_forward``. | :param enable: whether do ``normal_forward`` or ``fake_quant_forward``. | ||||
| """ | """ | ||||
| def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||||
| def __init__( | |||||
| self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs | |||||
| ): | |||||
| super().__init__() | super().__init__() | ||||
| if not dtype in _metadata_dict.keys(): | if not dtype in _metadata_dict.keys(): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -114,24 +116,28 @@ class TQT(_FakeQuantize): | |||||
| for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. | for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. | ||||
| """ | """ | ||||
| def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||||
| super().__init__(dtype, narrow_range, enable) | |||||
| self.scale = Parameter([0.0], dtype=np.float32) | |||||
def __init__(
    self,
    q_dict,
    dtype: str,
    narrow_range: bool = False,
    enable: bool = True,
    **kwargs
):
    r"""
    Initialize the fake-quantizer from already computed quantization parameters.

    :param q_dict: quantization parameters; ``q_dict["mode"]`` must be
        ``QuantMode.SYMMERTIC`` and ``q_dict["scale"]`` must be present and not ``None``.
    :param dtype: forwarded to the base class initializer.
    :param narrow_range: forwarded to the base class initializer.
    :param enable: forwarded to the base class initializer.
    :param kwargs: extra keyword arguments forwarded to the base class initializer.
    :raises AssertionError: if the mode is not symmetric or no scale is available.
    """
    super().__init__(dtype, narrow_range, enable, **kwargs)
    # Raise explicitly rather than using ``assert`` so that the check is not
    # stripped when Python runs with ``-O``; the exception type is unchanged.
    if not (q_dict["mode"] == QuantMode.SYMMERTIC):
        raise AssertionError("only symmetric quantization is supported by TQT")
    if "scale" not in q_dict or q_dict["scale"] is None:
        raise AssertionError("Can not get an initialized scale")
    # The scale is kept in the log2 domain: log(s) / log(2) == log2(s).
    self.scale = F.log(q_dict["scale"]) / math.log(2)
def fake_quant_forward(self, inp, q_dict=None):
    r"""
    Fake-quantize ``inp`` with ``TQT_Function`` using this module's ``scale``.

    :param inp: input passed to the TQT function.
    :param q_dict: accepted for interface compatibility; not read here.
    """
    # When enabled, TQT performs the fake-quant forward, which is where the
    # scale gets fine-tuned.
    tqt_op = TQT_Function(self.qmin, self.qmax)
    return tqt_op(inp, self.scale)
| def normal_foward(self, inp, q_dict=None): | |||||
| if q_dict["enable_observer"]: | |||||
| # when disable, TQT will do normal forward, initialize scale weight | |||||
| tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||||
| tmp_scale = F.log(tmp_scale / 127) / math.log(2) | |||||
| self.scale[...] = tmp_scale | |||||
| return inp | |||||
| def get_qparams(self): | def get_qparams(self): | ||||
| q_dict = get_qparam_dict(QuantMode.TQT) | |||||
| q_dict = get_qparam_dict(QuantMode.SYMMERTIC) | |||||
| q_dict["scale"] = 2 ** self.scale | q_dict["scale"] = 2 ** self.scale | ||||
| return q_dict | return q_dict | ||||
| @@ -7,6 +7,7 @@ | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| import math | import math | ||||
| from abc import abstractmethod | from abc import abstractmethod | ||||
| from copy import deepcopy | |||||
| import numpy as np | import numpy as np | ||||
| @@ -28,7 +29,7 @@ class Observer(Module): | |||||
| instead of 1 greater. Usually True for weight and False for activation. | instead of 1 greater. Usually True for weight and False for activation. | ||||
| """ | """ | ||||
| def __init__(self, dtype: str, narrow_range: bool = False): | |||||
| def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): | |||||
| super().__init__() | super().__init__() | ||||
| if dtype not in _metadata_dict.keys(): | if dtype not in _metadata_dict.keys(): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -81,8 +82,9 @@ class MinMaxObserver(Observer): | |||||
| eps=0.00001, | eps=0.00001, | ||||
| dtype="qint8", | dtype="qint8", | ||||
| narrow_range: bool = False, | narrow_range: bool = False, | ||||
| **kwargs | |||||
| ): | ): | ||||
| super().__init__(dtype, narrow_range) | |||||
| super().__init__(dtype, narrow_range, **kwargs) | |||||
| self.mode = mode | self.mode = mode | ||||
| self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) | self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) | ||||
| self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) | self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) | ||||
| @@ -105,7 +107,7 @@ class MinMaxObserver(Observer): | |||||
| else: | else: | ||||
| # use maximum to avoid the scale being too small at the beginning | # use maximum to avoid the scale being too small at the beginning | ||||
| q_dict["scale"] = F.maximum( | q_dict["scale"] = F.maximum( | ||||
| (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit, | |||||
| (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit | |||||
| ) | ) | ||||
| # calculate zero_point | # calculate zero_point | ||||
| q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"])) | q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"])) | ||||
| @@ -148,8 +150,9 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||||
| eps=0.00001, | eps=0.00001, | ||||
| dtype="qint8", | dtype="qint8", | ||||
| narrow_range: bool = False, | narrow_range: bool = False, | ||||
| **kwargs | |||||
| ): | ): | ||||
| super().__init__(mode, eps, dtype, narrow_range) | |||||
| super().__init__(mode, eps, dtype, narrow_range, **kwargs) | |||||
| self.momentum = Tensor(momentum) | self.momentum = Tensor(momentum) | ||||
| self.runtime_momentum = Tensor(0.0) | self.runtime_momentum = Tensor(0.0) | ||||
| @@ -205,8 +208,9 @@ class HistogramObserver(MinMaxObserver): | |||||
| eps=0.00001, | eps=0.00001, | ||||
| dtype="qint8", | dtype="qint8", | ||||
| narrow_range: bool = False, | narrow_range: bool = False, | ||||
| **kwargs | |||||
| ): | ): | ||||
| super().__init__(mode, eps, dtype, narrow_range) | |||||
| super().__init__(mode, eps, dtype, narrow_range, **kwargs) | |||||
| self.bins = bins | self.bins = bins | ||||
| self.upsample_rate = upsample_rate | self.upsample_rate = upsample_rate | ||||
| self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | ||||
| @@ -417,7 +421,7 @@ class HistogramObserver(MinMaxObserver): | |||||
| # combine the existing histogram and new histogram into 1 histogram | # combine the existing histogram and new histogram into 1 histogram | ||||
| # We do this by first upsampling the histogram to a dense grid | # We do this by first upsampling the histogram to a dense grid | ||||
| # and then downsampling the histogram efficiently | # and then downsampling the histogram efficiently | ||||
| (new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max( | |||||
| (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max( | |||||
| new_min, new_max, self.upsample_rate | new_min, new_max, self.upsample_rate | ||||
| ) | ) | ||||
| @@ -442,3 +446,34 @@ class HistogramObserver(MinMaxObserver): | |||||
| def forward(self, x_orig): | def forward(self, x_orig): | ||||
| self.sideeffect_forward(x_orig) | self.sideeffect_forward(x_orig) | ||||
| return x_orig | return x_orig | ||||
class PassiveObserver(Observer):
    r"""
    An observer whose :attr:`scale` is set directly instead of being
    computed from the data passed through :meth:`forward`.
    """

    def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs):
        r"""
        :param q_dict: quantization parameters to start from; must contain a
            ``"scale"`` entry that is not ``None``. A deep copy is stored.
        :param dtype: forwarded to the base class initializer.
        :param narrow_range: forwarded to the base class initializer.
        :param kwargs: extra keyword arguments forwarded to the base class initializer.
        :raises AssertionError: if ``q_dict`` carries no initialized scale.
        """
        super().__init__(dtype, narrow_range, **kwargs)
        # Reject an unusable q_dict before spending time on the deep copy.
        if "scale" not in q_dict or q_dict["scale"] is None:
            raise AssertionError("Can not get an initialized scale")
        self.q_dict = deepcopy(q_dict)
        # Snapshot of the initial scale, taken from the caller's q_dict.
        self.orig_scale = q_dict["scale"].numpy()

    @property
    def scale(self):
        r"""The ``"scale"`` entry of the stored quantization parameters."""
        return self.q_dict["scale"]

    @scale.setter
    def scale(self, value):
        # Explicit raise (not ``assert``) so the positivity check also runs
        # under ``python -O``; the exception type stays AssertionError.
        if not value > 0:
            raise AssertionError("scale must be positive, got {}".format(value))
        self.q_dict["scale"].set_value(value)

    def get_qparams(self):
        r"""Return the stored quantization parameter dict."""
        return self.q_dict

    def forward(self, x):
        r"""
        Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
        """
        return x
| @@ -13,6 +13,7 @@ from .observer import ( | |||||
| ExponentialMovingAverageObserver, | ExponentialMovingAverageObserver, | ||||
| HistogramObserver, | HistogramObserver, | ||||
| MinMaxObserver, | MinMaxObserver, | ||||
| PassiveObserver, | |||||
| SyncExponentialMovingAverageObserver, | SyncExponentialMovingAverageObserver, | ||||
| SyncMinMaxObserver, | SyncMinMaxObserver, | ||||
| ) | ) | ||||
| @@ -66,17 +67,22 @@ class QConfig: | |||||
| self.weight_fake_quant = weight_fake_quant | self.weight_fake_quant = weight_fake_quant | ||||
| self.act_fake_quant = act_fake_quant | self.act_fake_quant = act_fake_quant | ||||
def __eq__(self, other):
    r"""
    Compare two qconfigs field by field.

    Two :class:`functools.partial` values are considered equal when their
    ``func``, ``args`` and ``keywords`` are equal; any other pair of values
    is compared with ``==``.

    :param other: object to compare with.
    :return: ``True``/``False`` for objects exposing the four qconfig fields,
        ``NotImplemented`` for anything else so that Python can fall back to
        its default comparison instead of raising ``AttributeError``.
    """
    # NOTE(review): unless the enclosing class also defines ``__hash__``,
    # defining ``__eq__`` makes its instances unhashable.
    fields = (
        "weight_observer",
        "act_observer",
        "weight_fake_quant",
        "act_fake_quant",
    )
    if not all(hasattr(other, name) for name in fields):
        return NotImplemented

    def same(a, b):
        # partial objects compare by identity, so compare their parts instead.
        if isinstance(a, partial) and isinstance(b, partial):
            return a.func == b.func and a.args == b.args and a.keywords == b.keywords
        return a == b

    return all(same(getattr(self, name), getattr(other, name)) for name in fields)
| tqt_quant_qconfig = QConfig( | |||||
| weight_observer=partial( | |||||
| ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True | |||||
| ), | |||||
| act_observer=partial( | |||||
| ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
| ), | |||||
| weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||||
| act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||||
| ) | |||||
| min_max_fakequant_qconfig = QConfig( | min_max_fakequant_qconfig = QConfig( | ||||
| weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | ||||
| @@ -118,3 +124,17 @@ calibration_qconfig = QConfig( | |||||
| weight_fake_quant=None, | weight_fake_quant=None, | ||||
| act_fake_quant=None, | act_fake_quant=None, | ||||
| ) | ) | ||||
# QConfig with no observers: both weights and activations are handled by
# TQT fake-quantizers alone (narrow range for weights only).
tqt_qconfig = QConfig(
    weight_observer=None,
    act_observer=None,
    weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
    act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
# QConfig pairing PassiveObserver with FakeQuantize for both weights and
# activations (narrow range for weights only).
passive_qconfig = QConfig(
    weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True),
    act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)
| @@ -6,15 +6,18 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| from copy import copy, deepcopy | from copy import copy, deepcopy | ||||
| from functools import partial | |||||
| from typing import Callable, Dict, Tuple | from typing import Callable, Dict, Tuple | ||||
| import numpy as np | |||||
| from .. import module as Float | from .. import module as Float | ||||
| from ..functional import concat, norm | |||||
| from ..module import Module | from ..module import Module | ||||
| from ..module import qat as QAT | from ..module import qat as QAT | ||||
| from ..module import quantized as Quantized | from ..module import quantized as Quantized | ||||
| from ..module.qat import QATModule | from ..module.qat import QATModule | ||||
| from ..module.quantized import QuantizedModule | from ..module.quantized import QuantizedModule | ||||
| from .fake_quant import TQT | |||||
| from .qconfig import QConfig, ema_fakequant_qconfig | from .qconfig import QConfig, ema_fakequant_qconfig | ||||
| @@ -32,9 +35,7 @@ def _get_quantable_module_names(): | |||||
| return quantable_module_names | return quantable_module_names | ||||
| def _get_convert_dict() -> Tuple[ | |||||
| Dict[Module, QATModule], Dict[QATModule, QuantizedModule] | |||||
| ]: | |||||
| def _get_convert_dict(): | |||||
| quantable_module_names = _get_quantable_module_names() | quantable_module_names = _get_quantable_module_names() | ||||
| quantable_modules = [getattr(Float, key) for key in quantable_module_names] | quantable_modules = [getattr(Float, key) for key in quantable_module_names] | ||||
| @@ -47,6 +48,11 @@ def _get_convert_dict() -> Tuple[ | |||||
| _float2qat_dict, _qat2quantized_dict = _get_convert_dict() | _float2qat_dict, _qat2quantized_dict = _get_convert_dict() | ||||
# All classes that appear as keys of the QAT -> quantized conversion table.
qat_modules = tuple(_qat2quantized_dict.keys())


def is_qat(mod: Module):
    r"""
    Tell whether ``mod`` is an instance of one of the classes in ``qat_modules``.

    :param mod: module to test.
    """
    matches_qat_class = isinstance(mod, qat_modules)
    return matches_qat_class
| def quantize(module: Module, inplace: bool = True, mapping: dict = None): | def quantize(module: Module, inplace: bool = True, mapping: dict = None): | ||||
| @@ -133,6 +139,34 @@ def quantize_qat( | |||||
| return module | return module | ||||
def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
    r"""
    Reset :class:`~._FakeQuantize` and :class:`~.Observer` according to ``qconfig``

    Every submodule selected by :func:`is_qat` gets new observers and
    fake-quantizers built from ``qconfig``; each of them receives the
    submodule's current quantization parameters as ``q_dict``.

    :param module: root module to reset recursively.
    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
    :param inplace: whether to reset submodules in-place.
    :return: ``module`` itself, or the modified deep copy when ``inplace`` is False.
    """
    if not inplace:
        module = deepcopy(module)

    def build(factory, q_dict):
        # A ``None`` entry in the qconfig means "no such component".
        if factory is None:
            return None
        return factory(q_dict=q_dict)

    # Materialize the list first because attributes are replaced while looping.
    for qat_mod in list(module._flatten(predicate=is_qat)):
        if qat_mod.with_weight:
            # Read the current parameters before any component is replaced.
            weight_params = qat_mod.get_weight_qparams()
            qat_mod.weight_observer = build(qconfig.weight_observer, weight_params)
            qat_mod.weight_fake_quant = build(qconfig.weight_fake_quant, weight_params)
        if qat_mod.with_act:
            act_params = qat_mod.get_activation_qparams()
            qat_mod.act_observer = build(qconfig.act_observer, act_params)
            qat_mod.act_fake_quant = build(qconfig.act_fake_quant, act_params)

    return module
| def _propagate(module: Module, func_str: str, *args, **kargs): | def _propagate(module: Module, func_str: str, *args, **kargs): | ||||
| def fn(mod: Module): | def fn(mod: Module): | ||||
| if isinstance(mod, QATModule): | if isinstance(mod, QATModule): | ||||
| @@ -151,6 +185,85 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||||
| _propagate(module, "set_qconfig", qconfig) | _propagate(module, "set_qconfig", qconfig) | ||||
def hook_qat_module(module: Module, func: Callable):
    r"""
    Add hooks for all :class:`~.QATModule` submodule

    :param module: root module whose submodules are selected with :func:`is_qat`.
    :param func: callable registered as forward hook on each selected submodule.
    :return: list with the value returned by each ``register_forward_hook`` call.
    """
    qat_submodules = list(module._flatten(predicate=is_qat))
    return [sub.register_forward_hook(func) for sub in qat_submodules]
def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
    r"""
    Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
    Search for optimal scales.

    The input batch is duplicated; inside every hooked submodule the first
    half is run with fake quantization disabled and the second half with it
    enabled, and the candidate scale giving the highest cosine similarity
    between both outputs is kept. Weight observers are searched in a first
    forward pass, activation observers in a second one.

    :param module: root module.
    :param data: input tensor used to search optimal scale.
    :param start: lower bound of the search interval.
    :param stop: upper bound of the search interval.
    :param num: number of samples to search.
    :return: ``module``, whose observers' scales have been updated in place.
    """

    batch_size = data.shape[0]

    def get_cosine(x, y):
        # Cosine similarity over all non-batch axes, averaged over the batch.
        ndim = len(x.shape)
        axis = tuple(range(1, ndim))
        up = (x * y).sum(axis=axis)
        down = norm(x, axis=axis) * norm(y, axis=axis)
        sim = up / down
        return sim.mean(axis=0)

    def search(mod, inputs, outputs, where):
        # Drop the hooks so the forward calls below do not re-enter ``search``.
        # NOTE(review): this clears every forward hook on ``mod``, not only ours.
        mod._forward_hooks.clear()

        fp32_in = [inp[:batch_size] for inp in inputs]
        int8_in = [inp[batch_size:] for inp in inputs]

        disable_fake_quant(mod)
        fp32_out = mod(*fp32_in)
        enable_fake_quant(mod)

        ob = getattr(mod, where)
        if ob is None:
            return None

        orig_scale = ob.orig_scale
        # Start from the original scale and "-inf" similarity so that a valid
        # scale is always restored, even when no candidate reaches a positive
        # similarity (previously this ended with ``ob.scale = 0``).
        best_similarity = float("-inf")
        best_scale = orig_scale
        for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
            ob.scale = scale
            int8_out = mod(*int8_in)
            similarity = get_cosine(fp32_out, int8_out)
            if similarity > best_similarity:
                best_similarity = similarity
                best_scale = scale

        ob.scale = best_scale

        if where == "act_observer":
            # Recompute the quantized half with the selected activation scale.
            int8_out = mod(*int8_in)
        else:
            # Keep the quantized half produced by the hooked forward pass.
            int8_out = outputs[batch_size:]
        return concat([fp32_out, int8_out])

    # First half feeds the float path, second half the fake-quantized path.
    data = concat([data, data])

    hook_qat_module(module, partial(search, where="weight_observer"))
    module(data)

    hook_qat_module(module, partial(search, where="act_observer"))
    module(data)

    return module
| def disable_fake_quant(module: Module): | def disable_fake_quant(module: Module): | ||||
| r""" | r""" | ||||
| Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` | Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` | ||||
| @@ -54,17 +54,15 @@ class QuantMode(Enum): | |||||
| SYMMERTIC = 1 | SYMMERTIC = 1 | ||||
| ASYMMERTIC = 2 | ASYMMERTIC = 2 | ||||
| TQT = 3 | |||||
| qparam_dict = { | qparam_dict = { | ||||
| QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,}, | |||||
| QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None}, | |||||
| QuantMode.ASYMMERTIC: { | QuantMode.ASYMMERTIC: { | ||||
| "mode": QuantMode.ASYMMERTIC, | "mode": QuantMode.ASYMMERTIC, | ||||
| "scale": None, | "scale": None, | ||||
| "zero_point": None, | "zero_point": None, | ||||
| }, | }, | ||||
| QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,}, | |||||
| } | } | ||||
| @@ -1,116 +0,0 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| import pytest | |||||
| from megengine import module as Float | |||||
| from megengine import tensor | |||||
| from megengine.module import qat as QAT | |||||
| from megengine.quantization import min_max_fakequant_qconfig | |||||
| from megengine.quantization.quantize import ( | |||||
| _get_quantable_module_names, | |||||
| disable_fake_quant, | |||||
| quantize_qat, | |||||
| ) | |||||
| def test_get_quantable_module_names(): | |||||
| # need to make sure names from Quantized and QAT are the same | |||||
| def _get_qat_module_names(): | |||||
| def is_qat(key: str): | |||||
| value = getattr(QAT, key) | |||||
| return ( | |||||
| isinstance(value, type) | |||||
| and issubclass(value, QAT.QATModule) | |||||
| and value != QAT.QATModule | |||||
| ) | |||||
| # source should have all quantable modules' names | |||||
| quantable_module_names = [key for key in dir(QAT) if is_qat(key)] | |||||
| return quantable_module_names | |||||
| qat_module_names = _get_qat_module_names() | |||||
| quantized_module_names = _get_quantable_module_names() | |||||
| assert set(qat_module_names) == set(quantized_module_names) | |||||
| for key in qat_module_names: | |||||
| value = getattr(Float, key) | |||||
| assert ( | |||||
| isinstance(value, type) | |||||
| and issubclass(value, Float.Module) | |||||
| and value != Float.Module | |||||
| ) | |||||
| def test_disable_quantize(): | |||||
| class Net(Float.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.conv = Float.ConvBnRelu2d(3, 3, 3) | |||||
| self.conv.disable_quantize() | |||||
| def forward(self, x): | |||||
| return self.conv(x) | |||||
| net = Net() | |||||
| qat_net = quantize_qat(net, inplace=False) | |||||
| assert isinstance(qat_net.conv, Float.ConvBnRelu2d) | |||||
| assert isinstance(qat_net.conv.conv, Float.Conv2d) | |||||
| def test_convert_with_custom_mapping(): | |||||
| class FloatExample(Float.Module): | |||||
| def forward(self, x): | |||||
| return x | |||||
| class QATExample(QAT.QATModule): | |||||
| def forward(self, x): | |||||
| return x | |||||
| @classmethod | |||||
| def from_float_module(cls, float_module): | |||||
| return cls() | |||||
| class Net(Float.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.example = FloatExample() | |||||
| def forward(self, x): | |||||
| return self.example(x) | |||||
| net = Net() | |||||
| qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) | |||||
| assert isinstance(qat_net.example, QATExample) | |||||
| def test_disable_fake_quant(): | |||||
| class Net(Float.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.quant = Float.QuantStub() | |||||
| self.linear = Float.Linear(3, 3) | |||||
| self.dequant = Float.DequantStub() | |||||
| self.linear.bias.set_value(np.random.rand(3)) | |||||
| def forward(self, x): | |||||
| x = self.quant(x) | |||||
| x = self.linear(x) | |||||
| x = self.dequant(x) | |||||
| return x | |||||
| x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) | |||||
| net = Net() | |||||
| y1 = net(x).numpy() | |||||
| net = quantize_qat(net, min_max_fakequant_qconfig) | |||||
| y2 = net(x).numpy() | |||||
| disable_fake_quant(net) | |||||
| y3 = net(x).numpy() | |||||
| np.testing.assert_allclose(y1, y3) | |||||
| with pytest.raises(AssertionError): | |||||
| np.testing.assert_allclose(y2, y3) | |||||
| @@ -6,17 +6,53 @@ import pytest | |||||
| import megengine as mge | import megengine as mge | ||||
| import megengine.distributed as dist | import megengine.distributed as dist | ||||
| import megengine.quantization.observer as ob | |||||
| from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
| from megengine.quantization.observer import ( | |||||
| ExponentialMovingAverageObserver, | |||||
| MinMaxObserver, | |||||
| Observer, | |||||
| PassiveObserver, | |||||
| SyncExponentialMovingAverageObserver, | |||||
| SyncMinMaxObserver, | |||||
| ) | |||||
def test_observer():
    r"""Constructing the ``Observer`` base class directly must raise ``TypeError``."""
    dtype = "qint8"
    with pytest.raises(TypeError):
        Observer(dtype)
| def test_min_max_observer(): | def test_min_max_observer(): | ||||
| x = np.random.rand(3, 3, 3, 3).astype("float32") | x = np.random.rand(3, 3, 3, 3).astype("float32") | ||||
| np_min, np_max = x.min(), x.max() | np_min, np_max = x.min(), x.max() | ||||
| x = mge.tensor(x) | x = mge.tensor(x) | ||||
| m = ob.MinMaxObserver() | |||||
| m = MinMaxObserver() | |||||
| m(x) | m(x) | ||||
| assert m.min_val == np_min and m.max_val == np_max | |||||
| np.testing.assert_allclose(m.min_val.numpy(), np_min) | |||||
| np.testing.assert_allclose(m.max_val.numpy(), np_max) | |||||
def test_exponential_moving_average_observer():
    r"""
    After two batches the observer's min/max must equal the momentum-weighted
    mix of the per-batch min/max values.
    """
    momentum = np.random.rand()
    first = np.random.rand(3, 3, 3, 3).astype("float32")
    second = np.random.rand(3, 3, 3, 3).astype("float32")
    expected_min = first.min() * momentum + second.min() * (1 - momentum)
    expected_max = first.max() * momentum + second.max() * (1 - momentum)

    observer = ExponentialMovingAverageObserver(momentum=momentum)
    for batch in (first, second):
        observer(mge.tensor(batch, dtype=np.float32))

    np.testing.assert_allclose(observer.min_val.numpy(), expected_min)
    np.testing.assert_allclose(observer.max_val.numpy(), expected_max)
def test_passive_observer():
    r"""``PassiveObserver`` exposes the given scale and lets it be overwritten."""
    initial_params = {"scale": mge.tensor(1.0)}
    observer = PassiveObserver(initial_params, "qint8")

    # Freshly built: both the snapshot and the live scale equal the input.
    assert observer.orig_scale == 1.0
    assert observer.scale == 1.0

    # Assigning through the property updates what get_qparams reports.
    observer.scale = 2.0
    assert observer.scale == 2.0
    assert observer.get_qparams() == {"scale": mge.tensor(2.0)}
| @pytest.mark.skipif( | @pytest.mark.skipif( | ||||
| @@ -35,9 +71,39 @@ def test_sync_min_max_observer(): | |||||
| @dist.launcher | @dist.launcher | ||||
| def worker(): | def worker(): | ||||
| rank = dist.get_rank() | rank = dist.get_rank() | ||||
| m = ob.SyncMinMaxObserver() | |||||
| m = SyncMinMaxObserver() | |||||
| y = mge.tensor(x[rank * 3 : (rank + 1) * 3]) | y = mge.tensor(x[rank * 3 : (rank + 1) * 3]) | ||||
| m(y) | m(y) | ||||
| assert m.min_val == np_min and m.max_val == np_max | assert m.min_val == np_min and m.max_val == np_max | ||||
| worker() | worker() | ||||
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_sync_exponential_moving_average_observer():
    r"""
    Each rank observes its own 3-sample slice of two batches; the resulting
    min/max must match the momentum-weighted mix computed over the full batches.
    """
    world_size = get_device_count_by_fork("gpu")
    momentum = np.random.rand()
    first = np.random.rand(3 * world_size, 3, 3, 3).astype("float32")
    second = np.random.rand(3 * world_size, 3, 3, 3).astype("float32")
    expected_min = first.min() * momentum + second.min() * (1 - momentum)
    expected_max = first.max() * momentum + second.max() * (1 - momentum)

    @dist.launcher
    def worker():
        rank = dist.get_rank()
        observer = SyncExponentialMovingAverageObserver(momentum=momentum)
        # Slice of the global batch that belongs to this rank.
        begin, end = rank * 3, (rank + 1) * 3
        local_first = mge.tensor(first[begin:end])
        local_second = mge.tensor(second[begin:end])
        observer(local_first)
        observer(local_second)
        np.testing.assert_allclose(observer.min_val.numpy(), expected_min)
        np.testing.assert_allclose(observer.max_val.numpy(), expected_max)

    worker()
| @@ -0,0 +1,14 @@ | |||||
| from functools import partial | |||||
| from megengine.quantization import QConfig, tqt_qconfig | |||||
| from megengine.quantization.fake_quant import TQT | |||||
def test_equal():
    r"""A ``QConfig`` rebuilt from equal ``partial`` objects equals ``tqt_qconfig``."""
    weight_fq = partial(TQT, dtype="qint8", narrow_range=True)
    act_fq = partial(TQT, dtype="qint8", narrow_range=False)
    rebuilt = QConfig(
        weight_observer=None,
        act_observer=None,
        weight_fake_quant=weight_fq,
        act_fake_quant=act_fq,
    )
    assert rebuilt == tqt_qconfig
| @@ -0,0 +1,266 @@ | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import numpy as np | |||||
| import pytest | |||||
| from megengine import functional | |||||
| from megengine import module as Float | |||||
| from megengine import tensor | |||||
| from megengine.module import qat as QAT | |||||
| from megengine.module import quantized as Q | |||||
| from megengine.quantization import ( | |||||
| min_max_fakequant_qconfig, | |||||
| passive_qconfig, | |||||
| tqt_qconfig, | |||||
| ) | |||||
| from megengine.quantization.fake_quant import TQT, FakeQuantize | |||||
| from megengine.quantization.observer import MinMaxObserver, PassiveObserver | |||||
| from megengine.quantization.quantize import ( | |||||
| _get_quantable_module_names, | |||||
| apply_easy_quant, | |||||
| disable_fake_quant, | |||||
| disable_observer, | |||||
| enable_fake_quant, | |||||
| enable_observer, | |||||
| propagate_qconfig, | |||||
| quantize, | |||||
| quantize_qat, | |||||
| reset_qconfig, | |||||
| ) | |||||
class Net(Float.Module):
    r"""Float test network: ``QuantStub`` -> ``Linear(3, 3)`` -> ``DequantStub``."""

    def __init__(self):
        super().__init__()
        self.quant = Float.QuantStub()
        self.linear = Float.Linear(3, 3)
        self.dequant = Float.DequantStub()
        # Overwrite the bias with random values.
        random_bias = np.random.rand(3)
        self.linear.bias.set_value(random_bias)

    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))
class QATNet(Float.Module):
    r"""Test network built directly from QAT modules: stub -> ``Linear(3, 3)`` -> stub."""

    def __init__(self):
        super().__init__()
        self.quant = QAT.QuantStub()
        self.linear = QAT.Linear(3, 3)
        self.dequant = QAT.DequantStub()
        # Overwrite the bias with random values.
        random_bias = np.random.rand(3)
        self.linear.bias.set_value(random_bias)

    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))
def test_propagate_qconfig():
    r"""
    ``propagate_qconfig`` must attach observers and fake-quantizers to the
    quant stub (activation only) and the linear layer (weight and activation),
    and leave the dequant stub without any of them.
    """
    net = QATNet()
    propagate_qconfig(net, min_max_fakequant_qconfig)

    # Separate asserts so a failure points at the offending attribute.
    # Quant stub: activation side only.
    assert net.quant.weight_observer is None
    assert net.quant.weight_fake_quant is None
    assert isinstance(net.quant.act_observer, MinMaxObserver)
    assert isinstance(net.quant.act_fake_quant, FakeQuantize)

    # Linear: both weight and activation sides.
    assert isinstance(net.linear.weight_observer, MinMaxObserver)
    assert isinstance(net.linear.weight_fake_quant, FakeQuantize)
    assert isinstance(net.linear.act_observer, MinMaxObserver)
    assert isinstance(net.linear.act_fake_quant, FakeQuantize)

    # Dequant stub: all four slots empty. ``act_fake_quant`` is checked here
    # too, instead of checking ``act_observer`` twice.
    assert net.dequant.weight_observer is None
    assert net.dequant.weight_fake_quant is None
    assert net.dequant.act_observer is None
    assert net.dequant.act_fake_quant is None
def init_qat_net():
    r"""
    Build a ``QATNet`` with ``min_max_fakequant_qconfig`` and give the linear
    layer's weight and activation observers random min/max values.
    """
    net = QATNet()
    propagate_qconfig(net, min_max_fakequant_qconfig)

    # Index 0 is used for the weight observer, index 1 for the activation one.
    lows = np.random.randint(-127, 0, size=(2,))
    highs = np.random.randint(1, 127, size=(2,))
    observers = (net.linear.weight_observer, net.linear.act_observer)
    for idx, observer in enumerate(observers):
        observer.min_val.set_value(lows[idx])
        observer.max_val.set_value(highs[idx])
    return net
def test_reset_qconfig():
    r"""
    Resetting to ``passive_qconfig`` must keep the linear layer's weight and
    activation quantization parameters.
    """
    qat_net = init_qat_net()

    # Snapshot the parameters BEFORE the reset: ``reset_qconfig`` works in
    # place by default, so the returned network is the very same object and
    # comparing the two afterwards would compare the network with itself.
    weight_before = qat_net.linear.get_weight_qparams()
    act_before = qat_net.linear.get_activation_qparams()
    expected_weight_mode = weight_before["mode"]
    expected_weight_scale = weight_before["scale"].numpy()
    expected_act_mode = act_before["mode"]
    expected_act_scale = act_before["scale"].numpy()

    new_qat_net = reset_qconfig(qat_net, passive_qconfig)

    weight_after = new_qat_net.linear.get_weight_qparams()
    act_after = new_qat_net.linear.get_activation_qparams()
    assert weight_after["mode"] == expected_weight_mode
    np.testing.assert_allclose(weight_after["scale"].numpy(), expected_weight_scale)
    assert act_after["mode"] == expected_act_mode
    np.testing.assert_allclose(act_after["scale"].numpy(), expected_act_scale)
def test_enable_and_disable_observer():
    r"""``enable_observer`` / ``disable_observer`` toggle ``enabled`` on every observer."""
    net = init_qat_net()

    def observer_flags():
        # Re-read the attributes on every call so the current state is checked.
        return [
            net.quant.act_observer.enabled,
            net.linear.weight_observer.enabled,
            net.linear.act_observer.enabled,
        ]

    enable_observer(net)
    for flag in observer_flags():
        assert flag == True

    disable_observer(net)
    for flag in observer_flags():
        assert flag == False
def test_enable_and_disable_fake_quant():
    """``disable_fake_quant`` / ``enable_fake_quant`` toggle every fake-quant flag."""
    net = init_qat_net()

    def fake_quant_flags():
        # Re-read the attributes on every call so each check sees the state
        # left by the most recent toggle.
        return [
            net.quant.act_fake_quant.enabled,
            net.linear.weight_fake_quant.enabled,
            net.linear.act_fake_quant.enabled,
        ]

    disable_fake_quant(net)
    assert fake_quant_flags() == [False, False, False]

    enable_fake_quant(net)
    assert fake_quant_flags() == [True, True, True]
def init_observer(module, data):
    """Feed ``data`` through ``module`` once with only the observers active.

    On return the observers are disabled again and fake-quant is enabled.

    :param module: module to run; it is called as ``module(data)``.
    :param data: input passed to the single forward call.
    """
    # Phase 1: observers on, fake-quant off, one forward pass.
    enable_observer(module)
    disable_fake_quant(module)
    module(data)

    # Phase 2: flip both switches back.
    disable_observer(module)
    enable_fake_quant(module)
def test_enable_and_disable_all():
    """Toggling fake-quant must switch the net between float and quantized output."""
    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))

    net = Net()
    float_out = net(x).numpy()

    # NOTE(review): the qconfig is passed positionally here, while other
    # calls in this file use ``inplace=...`` / ``qconfig=...`` keywords --
    # confirm it lands on the intended parameter.
    net = quantize_qat(net, min_max_fakequant_qconfig)
    init_observer(net, x)
    fake_quant_out = net(x).numpy()

    disable_fake_quant(net)
    fake_quant_off_out = net(x).numpy()

    enable_fake_quant(net)
    fake_quant_on_again_out = net(x).numpy()

    # Fake-quant off reproduces the float result; turning it back on
    # reproduces the first fake-quantized result.
    np.testing.assert_allclose(float_out, fake_quant_off_out)
    np.testing.assert_allclose(fake_quant_out, fake_quant_on_again_out)

    # The two modes must differ from each other.
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(fake_quant_out, fake_quant_off_out)
def test_quantize_qat():
    """``quantize_qat`` turns each child of ``Net`` into its ``QAT`` class."""
    float_net = Net()
    qat_net = quantize_qat(
        float_net, inplace=False, qconfig=min_max_fakequant_qconfig
    )

    quant_stub = qat_net.quant
    assert isinstance(quant_stub, QAT.QuantStub)
    linear = qat_net.linear
    assert isinstance(linear, QAT.Linear)
    dequant_stub = qat_net.dequant
    assert isinstance(dequant_stub, QAT.DequantStub)
def test_quantize():
    """``quantize`` turns each child of the QAT net into its ``Q`` class."""
    source_net = init_qat_net()
    q_net = quantize(source_net, inplace=False)

    quant_stub = q_net.quant
    assert isinstance(quant_stub, Q.QuantStub)
    linear = q_net.linear
    assert isinstance(linear, Q.Linear)
    dequant_stub = q_net.dequant
    assert isinstance(dequant_stub, Q.DequantStub)
def test_apply_easy_quant():
    """After ``apply_easy_quant`` the net still carries ``PassiveObserver`` s."""
    base_net = init_qat_net()
    data = tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32)

    # Work on a copy configured with ``passive_qconfig``.
    eq_net = reset_qconfig(base_net, passive_qconfig, inplace=False)
    apply_easy_quant(eq_net, data, 0.9, 1.1, 10)

    quant_act_obs = eq_net.quant.act_observer
    assert isinstance(quant_act_obs, PassiveObserver)
    linear_weight_obs = eq_net.linear.weight_observer
    assert isinstance(linear_weight_obs, PassiveObserver)
    linear_act_obs = eq_net.linear.act_observer
    assert isinstance(linear_act_obs, PassiveObserver)

    # The dequant stub is expected to have no activation observer.
    assert eq_net.dequant.act_observer is None
def test_apply_tqt():
    """Resetting to ``tqt_qconfig`` installs ``TQT`` as every fake-quant."""
    base_net = init_qat_net()
    tqt_net = reset_qconfig(base_net, tqt_qconfig, inplace=False)

    quant_act_fq = tqt_net.quant.act_fake_quant
    assert isinstance(quant_act_fq, TQT)
    linear_weight_fq = tqt_net.linear.weight_fake_quant
    assert isinstance(linear_weight_fq, TQT)
    linear_act_fq = tqt_net.linear.act_fake_quant
    assert isinstance(linear_act_fq, TQT)

    # The dequant stub is expected to have no activation fake-quant.
    assert tqt_net.dequant.act_fake_quant is None
def test_get_quantable_module_names():
    """QAT module names must match the quantable names and exist in ``Float``."""

    def is_qat_module_class(name: str):
        # True for classes exported by ``QAT`` that derive from, but are not,
        # ``QAT.QATModule`` itself.
        candidate = getattr(QAT, name)
        if not isinstance(candidate, type):
            return False
        return issubclass(candidate, QAT.QATModule) and candidate != QAT.QATModule

    # Names collected from ``QAT`` and names reported by the helper under
    # test must be the same set.
    qat_names = [name for name in dir(QAT) if is_qat_module_class(name)]
    quantized_names = _get_quantable_module_names()
    assert set(qat_names) == set(quantized_names)

    # Each of those names must also resolve to a proper ``Float.Module``
    # subclass.
    for name in qat_names:
        float_cls = getattr(Float, name)
        assert (
            isinstance(float_cls, type)
            and issubclass(float_cls, Float.Module)
            and float_cls != Float.Module
        )
def test_disable_quantize():
    """A module that called ``disable_quantize()`` is left as a float module."""

    class Net(Float.Module):
        """Single fused conv block with quantization switched off."""

        def __init__(self):
            super().__init__()
            self.conv = Float.ConvBnRelu2d(3, 3, 3)
            self.conv.disable_quantize()

        def forward(self, x):
            return self.conv(x)

    float_net = Net()
    converted = quantize_qat(float_net, inplace=False)

    # Both the fused block and its inner conv must still be float classes.
    fused_block = converted.conv
    assert isinstance(fused_block, Float.ConvBnRelu2d)
    assert isinstance(fused_block.conv, Float.Conv2d)
def test_convert_with_custom_mapping():
    """``quantize_qat`` honours a user-supplied float -> QAT class mapping."""

    class FloatExample(Float.Module):
        """Identity float module used as the mapping key."""

        def forward(self, x):
            return x

    class QATExample(QAT.QATModule):
        """Identity QAT module used as the mapping value."""

        def forward(self, x):
            return x

        @classmethod
        def from_float_module(cls, float_module):
            # Nothing to carry over from the float module.
            return cls()

    class Net(Float.Module):
        """Wrapper holding a single ``FloatExample`` child."""

        def __init__(self):
            super().__init__()
            self.example = FloatExample()

        def forward(self, x):
            return self.example(x)

    custom_mapping = {FloatExample: QATExample}
    float_net = Net()
    qat_net = quantize_qat(float_net, inplace=False, mapping=custom_mapping)

    converted_child = qat_net.example
    assert isinstance(converted_child, QATExample)