GitOrigin-RevId: f16fbba2b7
tags/v0.5.0
| @@ -496,8 +496,11 @@ class QATModule(Module): | |||||
| self, target: Tensor, fq: "FakeQuantize", obs: "Observer" | self, target: Tensor, fq: "FakeQuantize", obs: "Observer" | ||||
| ): | ): | ||||
| oup = self.apply_observer(target, obs) | oup = self.apply_observer(target, obs) | ||||
| scale, zero_point = obs.get_qparams() | |||||
| return fq(oup, scale, zero_point) | |||||
| if self.quantizing == self.QATMode.CALIBRATION: | |||||
| return oup | |||||
| else: | |||||
| scale, zero_point = obs.get_qparams() | |||||
| return fq(oup, scale, zero_point) | |||||
| def set_qat_mode(self, mode: QATMode): | def set_qat_mode(self, mode: QATMode): | ||||
| r""" | r""" | ||||
| @@ -524,11 +527,7 @@ class QATModule(Module): | |||||
| """ | """ | ||||
| def __call__(self, *args, **kwargs): | def __call__(self, *args, **kwargs): | ||||
| if self.quantizing == self.QATMode.QAT: | |||||
| return self.forward_qat(*args, **kwargs) | |||||
| elif self.quantizing == self.QATMode.CALIBRATION: | |||||
| # TODO implement the CALIBRATION | |||||
| assert False | |||||
| return None | |||||
| else: | |||||
| if self.quantizing == self.QATMode.DISABLED: | |||||
| return self.forward(*args, **kwargs) | return self.forward(*args, **kwargs) | ||||
| else: | |||||
| return self.forward_qat(*args, **kwargs) | |||||
| @@ -20,11 +20,9 @@ class Concat(Module): | |||||
| A :class:`~.Module` to do quantized concat, inference only. | A :class:`~.Module` to do quantized concat, inference only. | ||||
| """ | """ | ||||
| def __init__(self): | |||||
| def __init__(self, dtype=None): | |||||
| super().__init__() | super().__init__() | ||||
| self.scale = 1.0 | |||||
| self.zero_point = 0.0 | |||||
| self.output_dtype = mgb.dtype.qint8(self.scale) | |||||
| self.output_dtype = dtype | |||||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | def forward(self, inps: Iterable[Tensor], axis: int = 0): | ||||
| if self.training: | if self.training: | ||||
| @@ -39,7 +37,4 @@ def to_quantized(float_module): | |||||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | ||||
| implemented here to avoid circular import. | implemented here to avoid circular import. | ||||
| """ | """ | ||||
| qmod = Concat() | |||||
| qmod.output_dtype = float_module.act_observer.get_dtype() | |||||
| qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||||
| return qmod | |||||
| return Concat(float_module.act_observer.get_dtype()) | |||||
| @@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d): | |||||
| groups: int = 1, | groups: int = 1, | ||||
| conv_mode: str = "CROSS_CORRELATION", | conv_mode: str = "CROSS_CORRELATION", | ||||
| compute_mode: str = "DEFAULT", | compute_mode: str = "DEFAULT", | ||||
| dtype=None, | |||||
| ): | ): | ||||
| super().__init__( | super().__init__( | ||||
| in_channels, | in_channels, | ||||
| @@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d): | |||||
| conv_mode, | conv_mode, | ||||
| compute_mode, | compute_mode, | ||||
| ) | ) | ||||
| self.scale = 1.0 | |||||
| self.zero_point = 0.0 | |||||
| self.output_dtype = mgb.dtype.qint8(self.scale) | |||||
| self.weight = self.weight.astype(self.output_dtype) | |||||
| self.bias = self.bias.astype(mgb.dtype.qint32(self.scale)) | |||||
| self.output_dtype = dtype | |||||
| def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): | def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): | ||||
| inp_scale = mgb.dtype.get_scale(inp.dtype) | inp_scale = mgb.dtype.get_scale(inp.dtype) | ||||
| @@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d): | |||||
| def to_quantized(quantized_class, float_module): | def to_quantized(quantized_class, float_module): | ||||
| output_dtype = float_module.act_observer.get_dtype() | |||||
| qconv = quantized_class( | qconv = quantized_class( | ||||
| float_module.conv.in_channels, | float_module.conv.in_channels, | ||||
| float_module.conv.out_channels, | float_module.conv.out_channels, | ||||
| @@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module): | |||||
| float_module.conv.padding, | float_module.conv.padding, | ||||
| float_module.conv.dilation, | float_module.conv.dilation, | ||||
| float_module.conv.groups, | float_module.conv.groups, | ||||
| dtype=output_dtype, | |||||
| ) | ) | ||||
| w_fold, b_fold = float_module.fold_weight_bias( | w_fold, b_fold = float_module.fold_weight_bias( | ||||
| float_module.bn.running_mean, float_module.bn.running_var | float_module.bn.running_mean, float_module.bn.running_var | ||||
| ) | ) | ||||
| weight = w_fold.astype(float_module.weight_observer.get_dtype()) | weight = w_fold.astype(float_module.weight_observer.get_dtype()) | ||||
| qconv.output_dtype = float_module.act_observer.get_dtype() | |||||
| qconv.weight = Parameter(weight.numpy()) | qconv.weight = Parameter(weight.numpy()) | ||||
| qconv.bias = Parameter(b_fold.numpy()) | qconv.bias = Parameter(b_fold.numpy()) | ||||
| qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams() | |||||
| return qconv | return qconv | ||||
| @@ -34,12 +34,10 @@ class Elemwise(Module): | |||||
| _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode | _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode | ||||
| def __init__(self, method): | |||||
| def __init__(self, method, dtype=None): | |||||
| super().__init__() | super().__init__() | ||||
| self.method = self._elemwise_multi_type_mode.convert("Q" + method) | self.method = self._elemwise_multi_type_mode.convert("Q" + method) | ||||
| self.scale = 1.0 | |||||
| self.zero_point = 0.0 | |||||
| self.output_dtype = mgb.dtype.qint8(self.scale) | |||||
| self.output_dtype = dtype | |||||
| def forward(self, *inps): | def forward(self, *inps): | ||||
| if self.training: | if self.training: | ||||
| @@ -53,7 +51,4 @@ def to_quantized(float_module): | |||||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | ||||
| implemented here to avoid circular import. | implemented here to avoid circular import. | ||||
| """ | """ | ||||
| qmod = Elemwise(float_module.method.name) | |||||
| qmod.output_dtype = float_module.act_observer.get_dtype() | |||||
| qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||||
| return qmod | |||||
| return Elemwise(float_module.method.name, float_module.act_observer.get_dtype()) | |||||
| @@ -16,11 +16,9 @@ class QuantStub(Module): | |||||
| A helper quantize operation on input and inference only. | A helper quantize operation on input and inference only. | ||||
| """ | """ | ||||
| def __init__(self): | |||||
| def __init__(self, dtype=None): | |||||
| super().__init__() | super().__init__() | ||||
| self.scale = 1.0 | |||||
| self.zero_point = 0.0 | |||||
| self.output_dtype = mgb.dtype.qint8(self.scale) | |||||
| self.output_dtype = dtype | |||||
| def forward(self, inp): | def forward(self, inp): | ||||
| if self.training: | if self.training: | ||||
| @@ -45,10 +43,7 @@ def to_quantized(float_module): | |||||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | ||||
| implemented here to avoid circular import. | implemented here to avoid circular import. | ||||
| """ | """ | ||||
| qmod = QuantStub() | |||||
| qmod.output_dtype = float_module.act_observer.get_dtype() | |||||
| qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||||
| return qmod | |||||
| return QuantStub(float_module.act_observer.get_dtype()) | |||||
| @register_method_to_class(Float.DequantStub) | @register_method_to_class(Float.DequantStub) | ||||
| @@ -57,5 +52,4 @@ def to_quantized(float_module): | |||||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | ||||
| implemented here to avoid circular import. | implemented here to avoid circular import. | ||||
| """ | """ | ||||
| qmod = DequantStub() | |||||
| return qmod | |||||
| return DequantStub() | |||||
| @@ -14,5 +14,6 @@ from .quantize import ( | |||||
| enable_fake_quant, | enable_fake_quant, | ||||
| enable_observer, | enable_observer, | ||||
| quantize, | quantize, | ||||
| quantize_calibration, | |||||
| quantize_qat, | quantize_qat, | ||||
| ) | ) | ||||
| @@ -11,7 +11,7 @@ import numpy as np | |||||
| from .. import functional as F | from .. import functional as F | ||||
| from .._internal.dtype import _metadata_dict, get_quantized_dtype | from .._internal.dtype import _metadata_dict, get_quantized_dtype | ||||
| from ..core import Buffer, Function, ones, tensor, zeros | |||||
| from ..core import Buffer, Function, tensor | |||||
| from ..module import Module | from ..module import Module | ||||
| @@ -34,6 +34,8 @@ def quantize(module: Module, inplace=True): | |||||
| else: | else: | ||||
| setattr(parent, key.split(".")[-1], submodule.to_quantized()) | setattr(parent, key.split(".")[-1], submodule.to_quantized()) | ||||
| return module | |||||
| def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | ||||
| r""" | r""" | ||||
| @@ -53,6 +55,25 @@ def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||||
| module.apply(fn) | module.apply(fn) | ||||
| def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||||
| r""" | |||||
| Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply` | |||||
| and set qconfig accordingly. | |||||
| :param module: root module to do convert recursively. | |||||
| :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||||
| default is :any:`~.qconfig.ema_fakequant_qconfig`. | |||||
| """ | |||||
| def fn(mod: Module): | |||||
| if isinstance(mod, QATModule): | |||||
| mod.set_qat_mode(QATModule.QATMode.CALIBRATION) | |||||
| mod.set_qconfig(qconfig) | |||||
| module.apply(fn) | |||||
| enable_observer(module) | |||||
| def disable_fake_quant(module: Module): | def disable_fake_quant(module: Module): | ||||
| r""" | r""" | ||||
| Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply` | Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply` | ||||