@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        setattr(result, "saved_tensors", tmp)
        self.saved_tensors = tmp
        return result
@@ -77,13 +77,19 @@ class QATModule(Module):
        r"""
        Get weight's quantization dtype as the method from ``qconfig``.
        """
        return self.weight_observer.get_dtype()
        if hasattr(self.weight_fake_quant, "get_dtype"):
            return self.weight_fake_quant.get_dtype()
        else:
            return self.weight_observer.get_dtype()
    def get_activation_dtype(self):
        r"""
        Get activation's quantization dtype as the method from ``qconfig``.
        """
        return self.act_observer.get_dtype()
        if hasattr(self.act_fake_quant, "get_dtype"):
            return self.act_fake_quant.get_dtype()
        else:
            return self.act_observer.get_dtype()
    @classmethod
    @abstractmethod
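# A minimal pure-Python sketch of the dtype fallback above (stand-in classes, hypothetical
# names, only for illustration): a fake-quant object that implements get_dtype, such as TQT,
# takes priority; otherwise the observer's dtype is used.
class _ObserverStub:
    def get_dtype(self):
        return "dtype-from-observer"
class _TQTStub:
    def get_dtype(self):
        return "dtype-from-fake-quant"
def _pick_dtype(fake_quant, observer):
    # same hasattr-based dispatch as get_weight_dtype / get_activation_dtype
    if hasattr(fake_quant, "get_dtype"):
        return fake_quant.get_dtype()
    return observer.get_dtype()
assert _pick_dtype(_TQTStub(), _ObserverStub()) == "dtype-from-fake-quant"
assert _pick_dtype(object(), _ObserverStub()) == "dtype-from-observer"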
@@ -12,4 +12,5 @@ from .qconfig import (
    calibration_qconfig,
    ema_fakequant_qconfig,
    min_max_fakequant_qconfig,
    tqt_quant_qconfig,
)
@@ -5,17 +5,20 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import math
import numpy as np
from .. import functional as F
from .._internal.dtype import _metadata_dict
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, Parameter
from ..jit import sideeffect
from ..module import Module
from .observer import ObserverMode, Round
class FakeQuantize(Module):
    r"""
    A module to do quant and dequant according to observer's scale and zero_point.
    """
class _FakeQuantize(Module):
    def __init__(self, dtype: str, enable: bool = True):
        super().__init__()
        if not dtype in _metadata_dict.keys():
@@ -35,25 +38,103 @@ class FakeQuantize(Module):
    def disable(self):
        self.enabled = False
    def fake_quant_forward(self, inp, q_dict):
        return inp
    def normal_forward(self, inp, q_dict):
        return inp
    def forward(self, inp, q_dict):
        if self.enabled:
            if q_dict["mode"] == ObserverMode.SYMMERTIC:
                scale = q_dict["scale"]
                # Quant
                oup = Round()(inp / scale)
                # clip
                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
                # DeQuant
                oup = oup * scale
                return oup
            else:
                scale = q_dict["scale"]
                zero_point = q_dict["zero_point"]
                # Quant
                oup = Round()(inp / scale) + zero_point
                # clip
                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
                # DeQuant
                oup = (oup - zero_point) * scale
                return oup
            return self.fake_quant_forward(inp, q_dict)
        else:
            return self.normal_forward(inp, q_dict)
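# A minimal pure-Python sketch of the dispatch above (stand-in classes, hypothetical names):
# subclasses only override fake_quant_forward / normal_forward, and forward() chooses one of
# them based on the `enabled` flag toggled by enable() / disable().
class _FakeQuantizeSketch:
    def __init__(self, enable=True):
        self.enabled = enable
    def fake_quant_forward(self, inp, q_dict):
        return inp
    def normal_forward(self, inp, q_dict):
        return inp
    def forward(self, inp, q_dict):
        if self.enabled:
            return self.fake_quant_forward(inp, q_dict)
        return self.normal_forward(inp, q_dict)
class _DoubleWhenEnabled(_FakeQuantizeSketch):
    def fake_quant_forward(self, inp, q_dict):
        return inp * 2  # placeholder for the real quant/dequant math
m = _DoubleWhenEnabled()
assert m.forward(3, {}) == 6      # enabled: fake-quant path
m.enabled = False
assert m.forward(3, {}) == 3      # disabled: normal path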
class TQT_Function(Function):
    def __init__(self, lowerbound, upperbound):
        super().__init__()
        self.lowerbound = lowerbound
        self.upperbound = upperbound
    def forward(self, inp, scale):
        t = 2 ** scale
        # t = F.maximum(t, 1e-4)
        inp_scaled = inp / t
        inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
        inp_rounded = F.round(inp_clipped)
        inp_flq = inp_rounded * t
        self.save_for_backward(inp_scaled, inp_rounded, t)
        return inp_flq
    def backward(self, grad_inp_flq):
        (inp_scaled, inp_rounded, t) = self.saved_tensors
        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
            inp_scaled > self.upperbound + 0.5
        )  # mask for accumulating the gradients of |data_scaled| > L
        mask_quant = F.abs(
            mask_clip - 1
        )  # mask for accumulating the gradients with |data_scaled| <= L
        grad_quant = (
            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
        )  # gradient within |data_scaled| <= L
        grad_clip = (
            grad_inp_flq * mask_clip * inp_rounded
        )  # gradient with |data_scaled| > L
        grad_s = grad_clip.sum() + grad_quant.sum()
        # dL/ds = dL/dt * t * ln(2)
        grad_s = grad_s * t * math.log(2)
        grad_inp = grad_inp_flq * mask_quant
        return grad_inp, grad_s
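# A small numpy walk-through of the backward rule above (illustrative numbers, assumed bounds
# of [-127, 127] and scale s = 0 so that t = 2**s = 1): an in-range input contributes
# (rounded - scaled) to dL/ds, a clipped input contributes its rounded (saturated) value, and
# the sum is rescaled by t * ln(2) because t = 2**s.
import numpy as np
t = 2.0 ** 0.0                               # t = 2**s with s = 0
x = np.array([0.3, 200.0])                   # one in-range value, one clipped value (L = 127)
x_scaled = x / t
x_rounded = np.round(np.clip(x_scaled, -127, 127))
grad_out = np.ones_like(x)                   # pretend dL/d(output) = 1 everywhere
mask_clip = (x_scaled < -127.5) | (x_scaled > 127.5)
mask_quant = ~mask_clip
grad_s = (grad_out * mask_quant * (x_rounded - x_scaled)).sum()
grad_s += (grad_out * mask_clip * x_rounded).sum()
grad_s *= t * np.log(2)                      # dL/ds = dL/dt * t * ln(2)
grad_inp = grad_out * mask_quant             # straight-through estimate, zero where clipped
print(grad_s, grad_inp)                      # grad_s = (0 - 0.3 + 127) * ln(2), grad_inp = [1., 0.]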
class TQT(_FakeQuantize):
    """
    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
    """
    def __init__(self, dtype: str, enable: bool = True):
        super().__init__(dtype, enable)
        self.scale = Parameter(0.0, dtype=np.float32)
    def fake_quant_forward(self, inp, q_dict):
        # when enabled, TQT runs the fake-quant forward and fine-tunes the scale
        return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
    def normal_forward(self, inp, q_dict):
        # when disabled, TQT runs the normal forward and initializes the scale from observer statistics
        tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
        tmp_scale = F.log(tmp_scale / 127) / F.log(2)
        F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
        return inp
    def get_dtype(self):
        return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
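# A quick numpy check of the scale initialization above (illustrative numbers): with observer
# statistics min_val = -6.0 and max_val = 3.0, the initial scale is log2(max(|min|, |max|) / 127),
# chosen so that 2**scale * 127 covers the largest observed magnitude.
import numpy as np
min_val, max_val = -6.0, 3.0
tmp_scale = np.maximum(np.abs(min_val), np.abs(max_val))     # 6.0
scale = np.log(tmp_scale / 127) / np.log(2)                  # log2(6 / 127) ~ -4.4
assert np.isclose(2.0 ** scale * 127, 6.0)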
class FakeQuantize(_FakeQuantize):
    r"""
    A module to do quant and dequant according to observer's scale and zero_point.
    """
    def fake_quant_forward(self, inp, q_dict):
        if q_dict["mode"] == ObserverMode.SYMMERTIC:
            scale = q_dict["scale"]
            # Quant
            oup = Round()(inp / scale)
            # clip
            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
            # DeQuant
            oup = oup * scale
            return oup
        else:
            scale = q_dict["scale"]
            zero_point = q_dict["zero_point"]
            # Quant
            oup = Round()(inp / scale) + zero_point
            # clip
            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
            # DeQuant
            oup = (oup - zero_point) * scale
            return oup
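# A minimal numpy sketch of the round / clip / dequant round trip above (illustrative values,
# int8-like qmin/qmax assumed):
import numpy as np
qmin, qmax = -128, 127
x = np.array([-1.0, 0.26, 3.0], dtype=np.float32)
# symmetric: q = clip(round(x / scale)); dequant = q * scale
scale = 0.5
q = np.clip(np.round(x / scale), qmin, qmax)
print(q * scale)                              # [-1.   0.5  3. ]
# asymmetric: q = clip(round(x / scale) + zero_point); dequant = (q - zero_point) * scale
zero_point = 10
q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
print((q - zero_point) * scale)               # same values, recovered to within one step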
@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
        min_val = F.minimum(0.0, inp_min_val)
        max_val = F.maximum(0.0, inp_max_val)
        q_dict = create_observer_dict(self.mode)
        q_dict["min_val"] = inp_min_val
        q_dict["max_val"] = inp_max_val
        if self.mode == ObserverMode.SYMMERTIC:
            symmetric_max_vals = F.maximum(-min_val, max_val)
            # use maximum to avoid the scale being too small at the beginning
@@ -1,12 +1,12 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..module import Module
from .fake_quant import FakeQuantize
from .fake_quant import TQT, FakeQuantize
from .observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
@@ -52,6 +52,12 @@ class QConfig:
        self.fake_quant = fake_quant
tqt_quant_qconfig = QConfig(
    weight_observer=ExponentialMovingAverageObserver,
    act_observer=ExponentialMovingAverageObserver,
    fake_quant=TQT,
)
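# Hypothetical usage sketch (the entry point name is assumed, not shown in this diff):
# tqt_quant_qconfig pairs EMA observers, which supply the min_val / max_val used to
# initialize TQT's scale, with the TQT fake quantizer, and would be passed wherever a
# QConfig is expected, e.g.:
#
#     from megengine.quantization import quantize_qat   # assumed helper
#     quantize_qat(float_net, qconfig=tqt_quant_qconfig)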
# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
    weight_observer=MinMaxObserver,
@@ -96,7 +96,6 @@ def test_deepcopy():
    origin = Sigmoid(0)
    new = copy.deepcopy(Sigmoid(0))
    assert new.param == origin.param
    assert new.saved_tensors == None
def test_save_context():
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
import megengine as mge
import megengine._internal as mgb
from megengine.core import tensor
from megengine.quantization.fake_quant import TQT_Function
from megengine.test import assertTensorClose
class numpy_TQT_Function:
    def __init__(self, lowerbound, upperbound):
        super().__init__()
        self.lowerbound = lowerbound
        self.upperbound = upperbound
    def forward(self, inp, scale):
        t = 2 ** scale
        # t = F.maximum(t, 1e-4)
        inp_scaled = inp / t
        inp_clipped = np.maximum(
            np.minimum(inp_scaled, self.upperbound), self.lowerbound
        )
        inp_rounded = np.round(inp_clipped)
        inp_flq = inp_rounded * t
        self.saved_tensors = (inp_scaled, inp_rounded, t)
        return inp_flq
    def backward(self, grad_inp_flq):
        (inp_scaled, inp_rounded, t) = self.saved_tensors
        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
            inp_scaled > self.upperbound + 0.5
        )  # mask for accumulating the gradients of |data_scaled| > L
        mask_quant = np.abs(
            mask_clip - 1
        )  # mask for accumulating the gradients with |data_scaled| <= L
        grad_quant = (
            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
        )  # gradient within |data_scaled| <= L
        grad_clip = (
            grad_inp_flq * mask_clip * inp_rounded
        )  # gradient with |data_scaled| > L
        grad_s = grad_clip.sum() + grad_quant.sum()
        # dL/ds = dL/dt * t * ln(2)
        grad_s = grad_s * t * np.log(2)
        grad_inp = grad_inp_flq * mask_quant
        return grad_inp, grad_s
def test_TQT():
    f = TQT_Function(-127, 127)
    nf = numpy_TQT_Function(-127, 127)
    def check_inp(a, b, c, a_np, b_np, c_np):
        assertTensorClose(
            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
        )
        c1, c2 = f.backward(c)
        c1_np, c2_np = nf.backward(c_np)
        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
        assertTensorClose(c2.numpy(), c2_np.astype("float32"))
    a = tensor()
    b = tensor()
    a_np = np.random.random((4, 3)).astype("float32")
    b_np = np.random.random((1)).astype("float32")
    a.set_value(a_np)
    b.set_value(b_np)
    check_inp(a, b, b, a_np, b_np, b_np)