GitOrigin-RevId: 6e39de9cec
tags/v1.0.0-rc1
| @@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule): | |||
| :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. | |||
| """ | |||
| def __init__(self, method): | |||
| super().__init__(method) | |||
| self.with_weight = False | |||
| with_weight = False | |||
| def forward(self, *inps): | |||
| return self.apply_quant_activation(super().forward(*inps)) | |||
| @@ -23,6 +23,9 @@ class QATModule(Module): | |||
| :func:`~.quantize.quantize` further. | |||
| """ | |||
| with_weight = True | |||
| with_act = True | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -32,9 +35,6 @@ class QATModule(Module): | |||
| self.weight_fake_quant = None # type: FakeQuantize | |||
| self.act_fake_quant = None # type: FakeQuantize | |||
| self.with_weight = True | |||
| self.with_act = True | |||
| def set_qconfig(self, qconfig: QConfig): | |||
| r""" | |||
| Set quantization related configs with ``qconfig``, including | |||
| @@ -51,29 +51,21 @@ class QATModule(Module): | |||
| self.weight_observer = safe_call(qconfig.weight_observer) | |||
| self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | |||
| def _enable_exec(self, with_module, func, enable): | |||
| if not with_module: | |||
| return | |||
| if enable: | |||
| func.enable() | |||
| else: | |||
| func.disable() | |||
| def set_fake_quant(self, enable): | |||
| if self.with_act: | |||
| if enable: | |||
| self.act_fake_quant.enable() | |||
| else: | |||
| self.act_fake_quant.disable() | |||
| if self.with_weight: | |||
| if enable: | |||
| self.weight_fake_quant.enable() | |||
| else: | |||
| self.weight_fake_quant.disable() | |||
| self._enable_exec(self.with_act, self.act_fake_quant, enable) | |||
| self._enable_exec(self.with_weight, self.weight_fake_quant, enable) | |||
| def set_observer(self, enable): | |||
| if self.with_act: | |||
| if enable: | |||
| self.act_observer.enable() | |||
| else: | |||
| self.act_observer.disable() | |||
| if self.with_weight: | |||
| if enable: | |||
| self.weight_observer.enable() | |||
| else: | |||
| self.weight_observer.disable() | |||
| self._enable_exec(self.with_act, self.act_observer, enable) | |||
| self._enable_exec(self.with_weight, self.weight_observer, enable) | |||
| def _apply_fakequant_with_observer( | |||
| self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | |||
| @@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule): | |||
| input after converted to :class:`~.QuantizedModule`. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.with_weight = False | |||
| with_weight = False | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(inp) | |||
| @@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule): | |||
| input after converted to :class:`~.QuantizedModule`. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.with_weight = False | |||
| self.with_act = False | |||
| with_weight = False | |||
| with_act = False | |||
| def forward(self, inp): | |||
| return inp | |||
| @@ -116,10 +116,11 @@ class TQT(_FakeQuantize): | |||
| return TQT_Function(self.qmin, self.qmax)(inp, self.scale) | |||
| def normal_foward(self, inp, q_dict): | |||
| # when disable, TQT will do normal forward, initialize scale weight | |||
| tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||
| tmp_scale = F.log(tmp_scale / 127) / F.log(2) | |||
| F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||
| if q_dict["enable_observer"]: | |||
| # when disable, TQT will do normal forward, initialize scale weight | |||
| tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||
| tmp_scale = F.log(tmp_scale / 127) / F.log(2) | |||
| F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||
| return inp | |||
| def get_qparams(self): | |||
| @@ -102,6 +102,7 @@ class MinMaxObserver(Observer): | |||
| q_dict = get_qparam_dict(self.mode) | |||
| q_dict["min_val"] = inp_min_val | |||
| q_dict["max_val"] = inp_max_val | |||
| q_dict["enable_observer"] = self.enable | |||
| if self.mode == QuantMode.SYMMERTIC: | |||
| symmetric_max_vals = F.maximum(-min_val, max_val) | |||
| # use maximun to avoid scale too small at the begin | |||
| @@ -14,6 +14,7 @@ from ..module import qat as QAT | |||
| from ..module import quantized as Quantized | |||
| from ..module.qat import QATModule | |||
| from ..module.quantized import QuantizedModule | |||
| from .fake_quant import TQT | |||
| from .qconfig import QConfig, ema_fakequant_qconfig | |||
| @@ -119,6 +120,14 @@ def quantize_qat( | |||
| return module | |||
| def _propagate(module: Module, func_str: str, *args, **kargs): | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| getattr(mod, func_str)(*args, **kargs) | |||
| module.apply(fn) | |||
| def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||
| r""" | |||
| Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. | |||
| @@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||
| :param module: root module to traverse recursively. | |||
| :param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||
| """ | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| mod.set_qconfig(qconfig) | |||
| module.apply(fn) | |||
| _propagate(module, "set_qconfig", qconfig) | |||
| def disable_fake_quant(module: Module): | |||
| @@ -141,11 +145,7 @@ def disable_fake_quant(module: Module): | |||
| :param module: root module to do disable fake quantization recursively. | |||
| """ | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| mod.set_fake_quant(False) | |||
| module.apply(fn) | |||
| _propagate(module, "set_fake_quant", False) | |||
| def disable_observer(module: Module): | |||
| @@ -155,11 +155,7 @@ def disable_observer(module: Module): | |||
| :param module: root module to do disable observer recursively. | |||
| """ | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| self.set_observer(False) | |||
| module.apply(fn) | |||
| _propagate(module, "set_observer", False) | |||
| def enable_fake_quant(module: Module): | |||
| @@ -169,11 +165,7 @@ def enable_fake_quant(module: Module): | |||
| :param module: root module to do enable fake quantization recursively. | |||
| """ | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| mod.set_fake_quant(True) | |||
| module.apply(fn) | |||
| _propagate(module, "set_fake_quant", True) | |||
| def enable_observer(module: Module): | |||
| @@ -183,8 +175,4 @@ def enable_observer(module: Module): | |||
| :param module: root module to do enable observer recursively. | |||
| """ | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| mod.set_observer(True) | |||
| module.apply(fn) | |||
| _propagate(module, "set_observer", False) | |||