@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
         memo[id(self)] = result
         for k, v in self.__dict__.items():
             setattr(result, k, copy.deepcopy(v, memo))
+        setattr(result, "saved_tensors", tmp)
         self.saved_tensors = tmp
         return result
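For readers skimming the hunk above: the fix temporarily detaches `saved_tensors` during the deep copy and then reattaches the same object to both the original and the copy. A minimal standalone sketch of that pattern on a toy class (illustrative only, not MegEngine's actual `Function`):

import copy


class SavesTensors:
    """Toy stand-in for an autograd function that stashes forward results."""

    def __init__(self):
        self.param = 1
        self.saved_tensors = None

    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors

    def __deepcopy__(self, memo):
        # Detach saved_tensors so it is not recursively deep-copied ...
        tmp = self.saved_tensors
        self.saved_tensors = None
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        # ... then hand the saved tensors to the copy and restore the original.
        setattr(result, "saved_tensors", tmp)
        self.saved_tensors = tmp
        return result


f = SavesTensors()
f.save_for_backward("x")
g = copy.deepcopy(f)
assert f.saved_tensors == ("x",) and g.saved_tensors == ("x",)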
@@ -77,13 +77,19 @@ class QATModule(Module):
         r"""
         Get weight's quantization dtype as the method from ``qconfig``.
         """
-        return self.weight_observer.get_dtype()
+        if hasattr(self.weight_fake_quant, "get_dtype"):
+            return self.weight_fake_quant.get_dtype()
+        else:
+            return self.weight_observer.get_dtype()
 
     def get_activation_dtype(self):
         r"""
         Get activation's quantization dtype as the method from ``qconfig``.
         """
-        return self.act_observer.get_dtype()
+        if hasattr(self.act_fake_quant, "get_dtype"):
+            return self.act_fake_quant.get_dtype()
+        else:
+            return self.act_observer.get_dtype()
 
     @classmethod
     @abstractmethod
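The intent of this hunk: a fake-quant module that tracks its own scale (such as the TQT module added below) can now report the quantized dtype directly, and the observer is only a fallback. A rough equivalent of that dispatch, as a standalone helper for illustration (names are hypothetical, not MegEngine API):

def quant_dtype(fake_quant, observer):
    # Prefer the fake-quant module's own dtype when it can provide one
    # (e.g. TQT, whose scale is a trained parameter); otherwise fall back
    # to the dtype derived from the observer's statistics.
    getter = getattr(fake_quant, "get_dtype", observer.get_dtype)
    return getter()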
@@ -12,4 +12,5 @@ from .qconfig import (
     calibration_qconfig,
     ema_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    tqt_quant_qconfig,
 )
@@ -5,17 +5,20 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+import math
+
+import numpy as np
+
 from .. import functional as F
-from .._internal.dtype import _metadata_dict
+from .._internal.dtype import _metadata_dict, get_quantized_dtype
+from ..core import Buffer, Function, Parameter
+from ..jit import sideeffect
 from ..module import Module
 from .observer import ObserverMode, Round
 
 
-class FakeQuantize(Module):
-    r"""
-    A module to do quant and dequant according to observer's scale and zero_point.
-    """
+class _FakeQuantize(Module):
     def __init__(self, dtype: str, enable: bool = True):
         super().__init__()
         if not dtype in _metadata_dict.keys():
@@ -35,25 +38,103 @@ class FakeQuantize(Module):
     def disable(self):
         self.enabled = False
 
+    def fake_quant_forward(self, inp, q_dict):
+        return inp
+
+    def normal_forward(self, inp, q_dict):
+        return inp
+
     def forward(self, inp, q_dict):
         if self.enabled:
-            if q_dict["mode"] == ObserverMode.SYMMERTIC:
-                scale = q_dict["scale"]
-                # Quant
-                oup = Round()(inp / scale)
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup) * scale
-                return oup
-            else:
-                scale = q_dict["scale"]
-                zero_point = q_dict["zero_point"]
-                # Quant
-                oup = Round()(inp / scale) + zero_point
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup - zero_point) * scale
-                return oup
+            return self.fake_quant_forward(inp, q_dict)
+        else:
+            return self.normal_forward(inp, q_dict)
+
+
+class TQT_Function(Function):
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
+        inp_rounded = F.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.save_for_backward(inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = F.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * math.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+class TQT(_FakeQuantize):
+    """
+    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
+    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
+    """
+
+    def __init__(self, dtype: str, enable: bool = True):
+        super().__init__(dtype, enable)
+        self.scale = Parameter(0.0, dtype=np.float32)
+
+    def fake_quant_forward(self, inp, q_dict):
+        # when enabled, do the fake-quant forward pass and finetune the scale
+        return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
+
+    def normal_forward(self, inp, q_dict):
+        # when disabled, pass the input through and initialize the scale
+        # from the observer's recorded range
+        tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
+        tmp_scale = F.log(tmp_scale / 127) / F.log(2)
+        F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
         return inp
+
+    def get_dtype(self):
+        return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
+
+
+class FakeQuantize(_FakeQuantize):
+    r"""
+    A module to do quant and dequant according to observer's scale and zero_point.
+    """
+
+    def fake_quant_forward(self, inp, q_dict):
+        if q_dict["mode"] == ObserverMode.SYMMERTIC:
+            scale = q_dict["scale"]
+            # Quant
+            oup = Round()(inp / scale)
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup) * scale
+            return oup
+        else:
+            scale = q_dict["scale"]
+            zero_point = q_dict["zero_point"]
+            # Quant
+            oup = Round()(inp / scale) + zero_point
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup - zero_point) * scale
+            return oup
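To make the fake-quant arithmetic above concrete: in symmetric mode the round trip is clip(round(x / scale), qmin, qmax) * scale, and TQT's backward applies the chain rule dL/ds = dL/dt * t * ln(2) because t = 2**s. A small NumPy-only sketch of the symmetric round trip, with illustrative values (not MegEngine code):

import numpy as np

def fake_quant_symmetric(x, scale, qmin=-127, qmax=127):
    """Quant -> clip -> dequant, mirroring FakeQuantize.fake_quant_forward."""
    q = np.round(x / scale)        # quantize onto the integer grid
    q = np.clip(q, qmin, qmax)     # saturate to the dtype's range
    return q * scale               # dequantize back to float

x = np.array([0.05, -1.23, 3.0], dtype=np.float32)
scale = 3.0 / 127                  # e.g. max_abs / qmax
print(fake_quant_symmetric(x, scale))  # values snapped to multiples of scale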
@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
         min_val = F.minimum(0.0, inp_min_val)
         max_val = F.maximum(0.0, inp_max_val)
         q_dict = create_observer_dict(self.mode)
+        q_dict["min_val"] = inp_min_val
+        q_dict["max_val"] = inp_max_val
         if self.mode == ObserverMode.SYMMERTIC:
             symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximum to avoid the scale being too small at the beginning
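The two added keys are what lets TQT.normal_forward seed its scale: it takes log2(max(|min_val|, |max_val|) / 127) so that 2**scale maps the observed range onto the int8 grid. A quick numeric check in plain NumPy, with illustrative numbers:

import numpy as np

min_val, max_val = -4.0, 6.0               # example observed range
max_abs = max(abs(min_val), abs(max_val))  # 6.0
scale = np.log2(max_abs / 127)             # about -4.40, what TQT stores
t = 2.0 ** scale                           # about 0.0472
print(scale, t, 127 * t)                   # 127 * t recovers max_abs = 6.0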
@@ -1,12 +1,12 @@
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ..module import Module
-from .fake_quant import FakeQuantize
+from .fake_quant import TQT, FakeQuantize
 from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
@@ -52,6 +52,12 @@ class QConfig:
         self.fake_quant = fake_quant
 
+
+tqt_quant_qconfig = QConfig(
+    weight_observer=ExponentialMovingAverageObserver,
+    act_observer=ExponentialMovingAverageObserver,
+    fake_quant=TQT,
+)
+
 # Default QAT QConfigs
 min_max_fakequant_qconfig = QConfig(
     weight_observer=MinMaxObserver,
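For orientation (not part of the diff): a qconfig like this is normally handed to MegEngine's QAT conversion helper, the same way as the existing default qconfigs. A rough usage sketch, assuming the usual quantize_qat entry point and ConvBn2d module of that release; treat the exact names as assumptions:

import megengine.module as M
from megengine.quantization import tqt_quant_qconfig
from megengine.quantization.quantize import quantize_qat

net = M.ConvBn2d(3, 8, 3)  # any float module that has a QAT counterpart
# Swap supported layers for QAT variants whose observers and fake-quant
# modules (here: EMA observers + TQT) come from the given qconfig.
quantize_qat(net, qconfig=tqt_quant_qconfig)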
@@ -96,7 +96,6 @@ def test_deepcopy():
     origin = Sigmoid(0)
     new = copy.deepcopy(Sigmoid(0))
     assert new.param == origin.param
-    assert new.saved_tensors == None
 
 
 def test_save_context():
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine._internal as mgb
+from megengine.core import tensor
+from megengine.quantization.fake_quant import TQT_Function
+from megengine.test import assertTensorClose
+
+
+class numpy_TQT_Function:
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = np.maximum(
+            np.minimum(inp_scaled, self.upperbound), self.lowerbound
+        )
+        inp_rounded = np.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.saved_tensors = (inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = np.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * np.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+def test_TQT():
+    f = TQT_Function(-127, 127)
+    nf = numpy_TQT_Function(-127, 127)
+
+    def check_inp(a, b, c, a_np, b_np, c_np):
+        assertTensorClose(
+            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
+        )
+        c1, c2 = f.backward(c)
+        c1_np, c2_np = nf.backward(c_np)
+        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
+        assertTensorClose(c2.numpy(), c2_np.astype("float32"))
+
+    a = tensor()
+    b = tensor()
+    a_np = np.random.random((4, 3)).astype("float32")
+    b_np = np.random.random((1)).astype("float32")
+    a.set_value(a_np)
+    b.set_value(b_np)
+    check_inp(a, b, b, a_np, b_np, b_np)