GitOrigin-RevId: 80cfb12d10
tags/v0.5.0
| @@ -27,10 +27,10 @@ from .utils import _decide_comp_node_and_comp_graph | |||
| def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: | |||
| """Applies a linear transformation to the input. | |||
| Refer to :class:`~.Linear` for more information. | |||
| Refer to :class:`~.module.linear.Linear` for more information. | |||
| :param inp: the input tensor with shape `(N, in_features)`. | |||
| :param weight: the weight with shape `(out_features, in_features)`. | |||
| :param bias: the bias with shape `(out_features,)`. | |||
| Default: ``None`` | |||
| """ | |||
| @@ -300,9 +300,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor: | |||
| def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor: | |||
| r""" | |||
| Performs the elementwise function: | |||
| .. math:: | |||
| \mathsf{softplus}(x) = \log(1+\exp(\beta x)) / \beta. | |||
| For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`. | |||
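| # Illustrative numpy sketch of the formula above (not MegEngine's kernel): | |||
| # log(1 + exp(beta * x)) / beta, switching to the identity when | |||
| # beta * x > threshold for numerical stability. | |||
| import numpy as np | |||
| def softplus_ref(x, beta=1.0, threshold=20.0): | |||
|     bx = beta * x | |||
|     capped = np.minimum(bx, threshold)  # avoid exp overflow in the dead branch | |||
|     return np.where(bx > threshold, x, np.log1p(np.exp(capped)) / beta) | |||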
| @@ -16,7 +16,7 @@ from .elemwise import Elemwise | |||
| from .embedding import Embedding | |||
| from .identity import Identity | |||
| from .linear import Linear | |||
| from .module import Module, QATModule | |||
| from .module import Module | |||
| from .parampack import ParamPack | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| @@ -9,19 +9,14 @@ from typing import Iterable | |||
| from .. import functional as F | |||
| from ..core.tensor import Tensor | |||
| from .module import QATModule | |||
| from .module import Module | |||
| class Concat(QATModule): | |||
| class Concat(Module): | |||
| r""" | |||
| A :class:`~.QATModule` to do functional concat, should replace concat with this module, | |||
| supporting ``qat`` mode and ``quantized`` mode. | |||
| A :class:`~.Module` to do functional concat. Could be replaced with the :class:`~.QATModule` | |||
| version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
| return F.concat(inps, axis) | |||
| def forward_qat(self, inps: Iterable[Tensor], axis: int = 0): | |||
| return self.apply_fakequant_with_observer( | |||
| self.forward(inps, axis), self.act_fake_quant, self.act_observer | |||
| ) | |||
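| # Usage sketch for the float ``Concat`` module above, assuming the v0.5 API: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| from megengine.module import Concat | |||
| m = Concat() | |||
| a = tensor(np.ones((2, 3), dtype="float32")) | |||
| b = tensor(np.zeros((2, 3), dtype="float32")) | |||
| out = m([a, b], axis=0)  # shape (4, 3) | |||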
| @@ -7,14 +7,13 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Tuple, Union | |||
| from ..core import ones, zeros | |||
| from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad | |||
| from ..functional import relu | |||
| from .batchnorm import BatchNorm2d | |||
| from .conv import Conv2d | |||
| from .module import QATModule | |||
| from .module import Module | |||
| class _ConvBn2d(QATModule): | |||
| class _ConvBnActivation2d(Module): | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| @@ -47,171 +46,24 @@ class _ConvBn2d(QATModule): | |||
| ) | |||
| self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) | |||
| def get_batch_mean_var(self, inp): | |||
| def _sum_channel(inp, axis=0, keepdims=True): | |||
| if isinstance(axis, int): | |||
| out = sum(inp, axis=axis, keepdims=keepdims) | |||
| elif isinstance(axis, tuple): | |||
| for idx, elem in enumerate(axis): | |||
| out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) | |||
| return out | |||
| sum1 = _sum_channel(inp, (0, 2, 3)) | |||
| sum2 = _sum_channel(inp ** 2, (0, 2, 3)) | |||
| reduce_size = inp.shapeof().prod() / inp.shapeof(1) | |||
| batch_mean = sum1 / reduce_size | |||
| batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size | |||
| return batch_mean, batch_var | |||
| def fold_weight_bias(self, bn_mean, bn_var): | |||
| # get fold bn conv param | |||
| # bn_istd = 1 / bn_std | |||
| # w_fold = gamma / bn_std * W | |||
| # b_fold = gamma * (b - bn_mean) / bn_std + beta | |||
| gamma = self.bn.weight | |||
| if gamma is None: | |||
| gamma = ones((self.bn.num_features), dtype="float32") | |||
| gamma = gamma.reshape(1, -1, 1, 1) | |||
| beta = self.bn.bias | |||
| if beta is None: | |||
| beta = zeros((self.bn.num_features), dtype="float32") | |||
| beta = beta.reshape(1, -1, 1, 1) | |||
| if bn_mean is None: | |||
| bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| if bn_var is None: | |||
| bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| conv_bias = self.conv.bias | |||
| if conv_bias is None: | |||
| conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| # bn_istd = 1 / bn_std | |||
| # w_fold = gamma / bn_std * W | |||
| scale_factor = gamma * bn_istd | |||
| if self.conv.groups == 1: | |||
| w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) | |||
| else: | |||
| w_fold = self.conv.weight * scale_factor.reshape( | |||
| self.conv.groups, -1, 1, 1, 1 | |||
| ) | |||
| # b_fold = gamma * (b - bn_mean) / bn_std + beta | |||
| b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd | |||
| return w_fold, b_fold | |||
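| # Scalar numpy check of the folding identity implemented above: batch-normalizing | |||
| # a conv output equals convolving with w_fold = gamma / std * W and adding | |||
| # b_fold = beta + gamma * (bias - mean) / std, where std = sqrt(var + eps). | |||
| import numpy as np | |||
| gamma, beta, mean, var, eps = 1.5, 0.2, 0.3, 4.0, 1e-5 | |||
| w, b, x = 0.7, 0.1, 2.0  # scalar stand-ins for weight, bias and input | |||
| std = np.sqrt(var + eps) | |||
| y_bn = gamma * ((w * x + b) - mean) / std + beta | |||
| y_fold = (gamma / std * w) * x + (beta + gamma * (b - mean) / std) | |||
| assert np.allclose(y_bn, y_fold) | |||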
| def update_running_mean_and_running_var( | |||
| self, bn_mean, bn_var, num_elements_per_channel | |||
| ): | |||
| # update running mean and running var. no grad, use unbiased bn var | |||
| bn_mean = zero_grad(bn_mean) | |||
| bn_var = ( | |||
| zero_grad(bn_var) | |||
| * num_elements_per_channel | |||
| / (num_elements_per_channel - 1) | |||
| ) | |||
| exponential_average_factor = 1 - self.bn.momentum | |||
| add_update( | |||
| self.bn.running_mean, | |||
| delta=bn_mean, | |||
| alpha=1 - exponential_average_factor, | |||
| beta=exponential_average_factor, | |||
| ) | |||
| add_update( | |||
| self.bn.running_var, | |||
| delta=bn_var, | |||
| alpha=1 - exponential_average_factor, | |||
| beta=exponential_average_factor, | |||
| ) | |||
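| # Plain-numpy reading of the two add_update calls above, assuming | |||
| # add_update(dest, delta, alpha, beta) computes dest = alpha * dest + beta * delta, | |||
| # i.e. running = momentum * running + (1 - momentum) * batch: | |||
| momentum = 0.9 | |||
| running_mean, batch_mean = 0.0, 1.0 | |||
| running_mean = momentum * running_mean + (1 - momentum) * batch_mean  # -> 0.1 | |||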
| def calc_conv_bn_qat(self, inp, approx=True): | |||
| if self.training and not approx: | |||
| conv = self.conv(inp) | |||
| bn_mean, bn_var = self.get_batch_mean_var(conv) | |||
| num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) | |||
| self.update_running_mean_and_running_var( | |||
| bn_mean, bn_var, num_elements_per_channel | |||
| ) | |||
| else: | |||
| bn_mean, bn_var = self.bn.running_mean, self.bn.running_var | |||
| # get gamma and beta in BatchNorm | |||
| gamma = self.bn.weight | |||
| if gamma is None: | |||
| gamma = ones((self.bn.num_features), dtype="float32") | |||
| gamma = gamma.reshape(1, -1, 1, 1) | |||
| beta = self.bn.bias | |||
| if beta is None: | |||
| beta = zeros((self.bn.num_features), dtype="float32") | |||
| beta = beta.reshape(1, -1, 1, 1) | |||
| # conv_bias | |||
| conv_bias = self.conv.bias | |||
| if conv_bias is None: | |||
| conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| # bn_istd = 1 / bn_std | |||
| # w_fold = gamma / bn_std * W | |||
| scale_factor = gamma * bn_istd | |||
| if self.conv.groups == 1: | |||
| w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) | |||
| else: | |||
| w_fold = self.conv.weight * scale_factor.reshape( | |||
| self.conv.groups, -1, 1, 1, 1 | |||
| ) | |||
| b_fold = None | |||
| if not (self.training and approx): | |||
| # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta | |||
| b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd | |||
| w_qat = self.apply_fakequant_with_observer( | |||
| w_fold, self.weight_fake_quant, self.weight_observer | |||
| ) | |||
| conv = self.conv.calc_conv(inp, w_qat, b_fold) | |||
| if not (self.training and approx): | |||
| return conv | |||
| # rescale conv to get original conv output | |||
| orig_conv = conv / scale_factor.reshape(1, -1, 1, 1) | |||
| if self.conv.bias is not None: | |||
| orig_conv = orig_conv + self.conv.bias | |||
| # calculate batch norm | |||
| bn_mean, bn_var = self.get_batch_mean_var(orig_conv) | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| conv = gamma * bn_istd * (orig_conv - bn_mean) + beta | |||
| num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) | |||
| self.update_running_mean_and_running_var( | |||
| bn_mean, bn_var, num_elements_per_channel | |||
| ) | |||
| return conv | |||
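| # Scalar sketch of the rescaling trick above: convolving with the folded weight | |||
| # w_fold = s * W (s = gamma * bn_istd, no bias), then dividing by s and adding | |||
| # the original bias, recovers the raw conv output so true batch statistics can | |||
| # still be computed (fake-quant is ignored in this sketch): | |||
| import numpy as np | |||
| s, W, x, bias = 2.0, 0.7, 3.0, 0.1  # scalar stand-ins | |||
| conv_fold = (s * W) * x | |||
| orig = conv_fold / s + bias | |||
| assert np.allclose(orig, W * x + bias) | |||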
| class ConvBn2d(_ConvBn2d): | |||
| class ConvBn2d(_ConvBnActivation2d): | |||
| r""" | |||
| A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode | |||
| and ``normal`` mode. | |||
| A fused :class:`~.Module` including Conv2d and BatchNorm2d. Could be replaced | |||
| with the :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using | |||
| :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward_qat(self, inp): | |||
| return self.apply_fakequant_with_observer( | |||
| self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer | |||
| ) | |||
| def forward(self, inp): | |||
| return self.bn(self.conv(inp)) | |||
| class ConvBnRelu2d(_ConvBn2d): | |||
| class ConvBnRelu2d(_ConvBnActivation2d): | |||
| r""" | |||
| A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat`` | |||
| mode and ``normal`` mode. | |||
| A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced | |||
| with the :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using | |||
| :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward_qat(self, inp): | |||
| return self.apply_fakequant_with_observer( | |||
| relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer | |||
| ) | |||
| def forward(self, inp): | |||
| return relu(self.bn(self.conv(inp))) | |||
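| # Minimal float-mode usage of the fused modules above (v0.5 API assumed): | |||
| import numpy as np | |||
| from megengine import tensor | |||
| from megengine.module import ConvBnRelu2d | |||
| m = ConvBnRelu2d(3, 8, 3)  # in_channels, out_channels, kernel_size | |||
| x = tensor(np.random.randn(2, 3, 32, 32).astype("float32")) | |||
| y = m(x)  # conv -> batchnorm -> relu | |||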
| @@ -8,7 +8,7 @@ | |||
| from .. import _internal as mgb | |||
| from ..core import Tensor, wrap_io_tensor | |||
| from ..core.graph import _use_default_if_none | |||
| from .module import QATModule | |||
| from .module import Module | |||
| @wrap_io_tensor | |||
| @@ -22,10 +22,10 @@ def _elemwise_func(mode, *inputs, **kwargs) -> Tensor: | |||
| return mgb.opr.elemwise(*inputs, mode=mode, **kwargs) | |||
| class Elemwise(QATModule): | |||
| class Elemwise(Module): | |||
| r""" | |||
| A :class:`~.QATModule` to do elemwise operator, should functional operator with this module, | |||
| supporting ``qat`` mode and ``normal`` mode. | |||
| A :class:`~.Module` to do elemwise operator. Could be replaced with the :class:`~.QATModule` | |||
| version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`. | |||
| :param method: the elemwise method, supporting the following strings. | |||
| It will do the normal elemwise operator for float. | |||
| @@ -88,8 +88,3 @@ class Elemwise(QATModule): | |||
| def forward(self, *inps): | |||
| return _elemwise_func(self.method, *inps) | |||
| def forward_qat(self, *inps): | |||
| return self.apply_fakequant_with_observer( | |||
| self.forward(*inps), self.act_fake_quant, self.act_observer, | |||
| ) | |||
| @@ -1,4 +1,3 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| @@ -11,10 +10,10 @@ import numpy as np | |||
| from .. import functional as F | |||
| from ..core import Parameter | |||
| from . import init | |||
| from .module import QATModule | |||
| from .module import Module | |||
| class Linear(QATModule): | |||
| class Linear(Module): | |||
| r"""Applies a linear transformation to the input. For instance, if input | |||
| is x, then output y is: | |||
| @@ -60,13 +59,3 @@ class Linear(QATModule): | |||
| def forward(self, x): | |||
| return self._calc_linear(x, self.weight, self.bias) | |||
| def forward_qat(self, x): | |||
| w_qat = self.apply_fakequant_with_observer( | |||
| self.weight, self.weight_fake_quant, self.weight_observer | |||
| ) | |||
| return self.apply_fakequant_with_observer( | |||
| self._calc_linear(x, w_qat, self.bias), | |||
| self.act_fake_quant, | |||
| self.act_observer, | |||
| ) | |||
| @@ -7,7 +7,6 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABCMeta, abstractmethod | |||
| from collections import OrderedDict | |||
| from enum import Enum | |||
| from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
| import numpy as np | |||
| @@ -443,98 +442,3 @@ class Module(metaclass=ABCMeta): | |||
| loaded.append(k) | |||
| return set(loaded), set(skipped) | |||
| class QATModule(Module): | |||
| r""" | |||
| Base class of quantization related Module. Add extra forward methods | |||
| :meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for | |||
| ``qat``(quantization aware training) mode and ``quantized`` mode respectively. | |||
| Use :meth:`~.QATModule.quant` to switch between ``QAT`` and ``NORMAL`` mode, | |||
| and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode, | |||
| which is irreversible. | |||
| If you want to recursively switch mode for all QATModule in network, use | |||
| functions in :mod:`~.quantization.quantize`. | |||
| """ | |||
| class QATMode(Enum): | |||
| DISABLED = 1 | |||
| QAT = 2 | |||
| CALIBRATION = 3 | |||
| def __init__(self): | |||
| from ..quantization import ( | |||
| QConfig, | |||
| FakeQuantize, | |||
| Observer, | |||
| ) # pylint: disable=all | |||
| super().__init__() | |||
| self.quantizing = self.QATMode.DISABLED | |||
| self.scale = None | |||
| self.weight_observer = None # type: Observer | |||
| self.act_observer = None # type: Observer | |||
| self.weight_fake_quant = None # type: FakeQuantize | |||
| self.act_fake_quant = None # type: FakeQuantize | |||
| def set_qconfig(self, qconfig: "QConfig"): | |||
| self.weight_observer = qconfig.weight_observer() | |||
| self.act_observer = qconfig.act_observer() | |||
| self.weight_fake_quant = ( | |||
| None | |||
| if qconfig.fake_quant is None | |||
| else qconfig.fake_quant(self.weight_observer.dtype) | |||
| ) | |||
| self.act_fake_quant = ( | |||
| None | |||
| if qconfig.fake_quant is None | |||
| else qconfig.fake_quant(self.act_observer.dtype) | |||
| ) | |||
| def apply_observer(self, target: Tensor, obs: "Observer"): | |||
| return obs(target) | |||
| def apply_fakequant_with_observer( | |||
| self, target: Tensor, fq: "FakeQuantize", obs: "Observer" | |||
| ): | |||
| oup = self.apply_observer(target, obs) | |||
| if fq is not None: | |||
| q_dict = obs.get_qparams() | |||
| oup = fq(oup, q_dict) | |||
| return oup | |||
| def set_qat_mode(self, mode: QATMode): | |||
| r""" | |||
| Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``, | |||
| ``QAT``,``CALIBRATION``. | |||
| """ | |||
| if not isinstance(mode, self.QATMode): | |||
| raise TypeError("mode must be QATMode Enum type") | |||
| self.quantizing = mode | |||
| def to_quantized(self): | |||
| r""" | |||
| Return a new :class:`~.Module` with quantized parameters of ``self`` | |||
| according to scale and zero_point in ``self.xxx_observer``. | |||
| """ | |||
| raise NotImplementedError( | |||
| "Use megengine.quantization.quantize to register the method." | |||
| ) | |||
| @abstractmethod | |||
| def forward_qat(self, *args, **kwargs): | |||
| r""" | |||
| Forward method for ``qat`` mode. | |||
| """ | |||
| def __call__(self, *args, **kwargs): | |||
| if self.quantizing == self.QATMode.DISABLED: | |||
| return self.forward(*args, **kwargs) | |||
| else: | |||
| return self.forward_qat(*args, **kwargs) | |||
| @@ -0,0 +1,13 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .concat import Concat | |||
| from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | |||
| from .elemwise import Elemwise | |||
| from .linear import Linear | |||
| from .module import QATModule | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| @@ -0,0 +1,30 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable | |||
| from ...core.tensor import Tensor | |||
| from .. import concat as Float | |||
| from .module import QATModule | |||
| class Concat(Float.Concat, QATModule): | |||
| r""" | |||
| A :class:`~.QATModule` to do functional concat with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
| """ | |||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
| return self.apply_quant_activation(super().forward(inps, axis)) | |||
| @classmethod | |||
| def from_float_module(cls, float_module): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| return cls() | |||
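| # Conversion sketch using the hook above; import paths assume the new module | |||
| # layout introduced in this diff: | |||
| from megengine.module import Concat as FloatConcat | |||
| from megengine.module.qat import Concat as QATConcat | |||
| qat_concat = QATConcat.from_float_module(FloatConcat()) | |||
| # normally done for a whole network at once via quantize_qat(net) | |||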
| @@ -0,0 +1,193 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ...core import ones, zeros | |||
| from ...functional import add_update, relu, sqrt, sum, zero_grad | |||
| from .. import conv_bn_relu as Float | |||
| from .module import QATModule | |||
| class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule): | |||
| def get_batch_mean_var(self, inp): | |||
| def _sum_channel(inp, axis=0, keepdims=True): | |||
| if isinstance(axis, int): | |||
| out = sum(inp, axis=axis, keepdims=keepdims) | |||
| elif isinstance(axis, tuple): | |||
| for idx, elem in enumerate(axis): | |||
| out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) | |||
| return out | |||
| sum1 = _sum_channel(inp, (0, 2, 3)) | |||
| sum2 = _sum_channel(inp ** 2, (0, 2, 3)) | |||
| reduce_size = inp.shapeof().prod() / inp.shapeof(1) | |||
| batch_mean = sum1 / reduce_size | |||
| batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size | |||
| return batch_mean, batch_var | |||
| def fold_weight_bias(self, bn_mean, bn_var): | |||
| # get fold bn conv param | |||
| # bn_istd = 1 / bn_std | |||
| # w_fold = gamma / bn_std * W | |||
| # b_fold = gamma * (b - bn_mean) / bn_std + beta | |||
| gamma = self.bn.weight | |||
| if gamma is None: | |||
| gamma = ones((self.bn.num_features), dtype="float32") | |||
| gamma = gamma.reshape(1, -1, 1, 1) | |||
| beta = self.bn.bias | |||
| if beta is None: | |||
| beta = zeros((self.bn.num_features), dtype="float32") | |||
| beta = beta.reshape(1, -1, 1, 1) | |||
| if bn_mean is None: | |||
| bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| if bn_var is None: | |||
| bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| conv_bias = self.conv.bias | |||
| if conv_bias is None: | |||
| conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| # bn_istd = 1 / bn_std | |||
| # w_fold = gamma / bn_std * W | |||
| scale_factor = gamma * bn_istd | |||
| if self.conv.groups == 1: | |||
| w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) | |||
| else: | |||
| w_fold = self.conv.weight * scale_factor.reshape( | |||
| self.conv.groups, -1, 1, 1, 1 | |||
| ) | |||
| # b_fold = gamma * (b - bn_mean) / bn_std + beta | |||
| b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd | |||
| return w_fold, b_fold | |||
| def update_running_mean_and_running_var( | |||
| self, bn_mean, bn_var, num_elements_per_channel | |||
| ): | |||
| # update running mean and running var. no grad, use unbiased bn var | |||
| bn_mean = zero_grad(bn_mean) | |||
| bn_var = ( | |||
| zero_grad(bn_var) | |||
| * num_elements_per_channel | |||
| / (num_elements_per_channel - 1) | |||
| ) | |||
| exponential_average_factor = 1 - self.bn.momentum | |||
| add_update( | |||
| self.bn.running_mean, | |||
| delta=bn_mean, | |||
| alpha=1 - exponential_average_factor, | |||
| beta=exponential_average_factor, | |||
| ) | |||
| add_update( | |||
| self.bn.running_var, | |||
| delta=bn_var, | |||
| alpha=1 - exponential_average_factor, | |||
| beta=exponential_average_factor, | |||
| ) | |||
| def calc_conv_bn_qat(self, inp, approx=True): | |||
| if self.training and not approx: | |||
| conv = self.conv(inp) | |||
| bn_mean, bn_var = self.get_batch_mean_var(conv) | |||
| num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) | |||
| self.update_running_mean_and_running_var( | |||
| bn_mean, bn_var, num_elements_per_channel | |||
| ) | |||
| else: | |||
| bn_mean, bn_var = self.bn.running_mean, self.bn.running_var | |||
| # get gamma and beta in BatchNorm | |||
| gamma = self.bn.weight | |||
| if gamma is None: | |||
| gamma = ones((self.bn.num_features), dtype="float32") | |||
| gamma = gamma.reshape(1, -1, 1, 1) | |||
| beta = self.bn.bias | |||
| if beta is None: | |||
| beta = zeros((self.bn.num_features), dtype="float32") | |||
| beta = beta.reshape(1, -1, 1, 1) | |||
| # conv_bias | |||
| conv_bias = self.conv.bias | |||
| if conv_bias is None: | |||
| conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| # bn_istd = 1 / bn_std | |||
| # w_fold = gamma / bn_std * W | |||
| scale_factor = gamma * bn_istd | |||
| if self.conv.groups == 1: | |||
| w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1) | |||
| else: | |||
| w_fold = self.conv.weight * scale_factor.reshape( | |||
| self.conv.groups, -1, 1, 1, 1 | |||
| ) | |||
| b_fold = None | |||
| if not (self.training and approx): | |||
| # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta | |||
| b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd | |||
| w_qat = self.apply_quant_weight(w_fold) | |||
| conv = self.conv.calc_conv(inp, w_qat, b_fold) | |||
| if not (self.training and approx): | |||
| return conv | |||
| # rescale conv to get original conv output | |||
| orig_conv = conv / scale_factor.reshape(1, -1, 1, 1) | |||
| if self.conv.bias is not None: | |||
| orig_conv = orig_conv + self.conv.bias | |||
| # calculate batch norm | |||
| bn_mean, bn_var = self.get_batch_mean_var(orig_conv) | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| conv = gamma * bn_istd * (orig_conv - bn_mean) + beta | |||
| num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1) | |||
| self.update_running_mean_and_running_var( | |||
| bn_mean, bn_var, num_elements_per_channel | |||
| ) | |||
| return conv | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float._ConvBnActivation2d): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| qat_module = cls( | |||
| float_module.conv.in_channels, | |||
| float_module.conv.out_channels, | |||
| float_module.conv.kernel_size, | |||
| float_module.conv.stride, | |||
| float_module.conv.padding, | |||
| float_module.conv.dilation, | |||
| float_module.conv.groups, | |||
| bool(float_module.conv.bias), | |||
| float_module.conv.conv_mode.name, | |||
| float_module.conv.compute_mode.name, | |||
| ) | |||
| qat_module.conv.weight = float_module.conv.weight | |||
| qat_module.conv.bias = float_module.conv.bias | |||
| qat_module.bn = float_module.bn | |||
| return qat_module | |||
| class ConvBn2d(_ConvBnActivation2d): | |||
| r""" | |||
| A fused :class:`~.QATModule` including Conv2d and BatchNorm2d with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(self.calc_conv_bn_qat(inp)) | |||
| class ConvBnRelu2d(_ConvBnActivation2d): | |||
| r""" | |||
| A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp))) | |||
| @@ -0,0 +1,29 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .. import elemwise as Float | |||
| from .module import QATModule | |||
| class Elemwise(Float.Elemwise, QATModule): | |||
| r""" | |||
| A :class:`~.QATModule` to do elemwise operator with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
| :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for details. | |||
| """ | |||
| def forward(self, *inps): | |||
| return self.apply_quant_activation(super().forward(*inps)) | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float.Elemwise): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| return cls(float_module.method.name) | |||
| @@ -0,0 +1,37 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .. import linear as Float | |||
| from .module import QATModule | |||
| class Linear(Float.Linear, QATModule): | |||
| r""" | |||
| A :class:`~.QATModule` version of :class:`~.module.linear.Linear`. | |||
| Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`. | |||
| :param in_features: size of each input sample. | |||
| :param out_features: size of each output sample. | |||
| :param bias: If set to ``False``, the layer will not learn an additive bias. | |||
| Default: ``True`` | |||
| """ | |||
| def forward(self, x): | |||
| w_qat = self.apply_quant_weight(self.weight) | |||
| return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),) | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float.Linear): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| qmod = cls(float_module.in_features, float_module.out_features) | |||
| qmod.weight = float_module.weight | |||
| qmod.bias = float_module.bias | |||
| return qmod | |||
| @@ -0,0 +1,96 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import abstractmethod | |||
| from ...core import Tensor | |||
| from ...quantization import FakeQuantize, Observer, QConfig | |||
| from ..module import Module | |||
| class QATModule(Module): | |||
| r""" | |||
| Base class of quantized-float related Modules, used mainly for QAT and calibration. | |||
| Use :meth:`~.QATModule.from_float_module` to generate an instance from a float | |||
| :class:`~.Module`, or use :func:`~.quantize.quantize_qat` to do it recursively and | |||
| automatically. It can be further converted to a :class:`~.QuantizedModule` for | |||
| deployment using :func:`~.quantize.quantize`. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.scale = None | |||
| self.weight_observer = None # type: Observer | |||
| self.act_observer = None # type: Observer | |||
| self.weight_fake_quant = None # type: FakeQuantize | |||
| self.act_fake_quant = None # type: FakeQuantize | |||
| def set_qconfig(self, qconfig: QConfig): | |||
| r""" | |||
| Set quantization related configs with ``qconfig``, including | |||
| observer and fake_quant for weight and activation. | |||
| """ | |||
| self.weight_observer = qconfig.weight_observer() | |||
| self.act_observer = qconfig.act_observer() | |||
| if qconfig.fake_quant is None: | |||
| self.weight_fake_quant = None | |||
| self.act_fake_quant = None | |||
| else: | |||
| self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) | |||
| self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) | |||
| def _apply_fakequant_with_observer( | |||
| self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | |||
| ): | |||
| oup = observer(target) | |||
| if fake_quant is None: | |||
| return oup | |||
| else: | |||
| q_dict = observer.get_qparams() | |||
| return fake_quant(oup, q_dict) | |||
| def apply_quant_weight(self, target: Tensor): | |||
| r""" | |||
| Apply weight's observer and fake_quant from ``qconfig`` on ``target``. | |||
| """ | |||
| return self._apply_fakequant_with_observer( | |||
| target, self.weight_fake_quant, self.weight_observer | |||
| ) | |||
| def apply_quant_activation(self, target: Tensor): | |||
| r""" | |||
| Apply activation's observer and fake_quant from ``qconfig`` on ``target``. | |||
| """ | |||
| return self._apply_fakequant_with_observer( | |||
| target, self.act_fake_quant, self.act_observer | |||
| ) | |||
| def get_weight_dtype(self): | |||
| r""" | |||
| Get weight's quantization dtype as configured by ``qconfig``. | |||
| """ | |||
| return self.weight_observer.get_dtype() | |||
| def get_activation_dtype(self): | |||
| r""" | |||
| Get activation's quantization dtype as configured by ``qconfig``. | |||
| """ | |||
| return self.act_observer.get_dtype() | |||
| @classmethod | |||
| @abstractmethod | |||
| def from_float_module(cls, float_module: Module): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| @@ -0,0 +1,45 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .. import quant_dequant as Float | |||
| from .module import QATModule | |||
| class QuantStub(Float.QuantStub, QATModule): | |||
| r""" | |||
| A helper :class:`~.QATModule` that simply returns the input, but will quantize | |||
| the input after being converted to a :class:`~.QuantizedModule`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(inp) | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float.QuantStub): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| return cls() | |||
| class DequantStub(Float.DequantStub, QATModule): | |||
| r""" | |||
| A helper :class:`~.QATModule` that simply returns the input, but will de-quantize | |||
| the input after being converted to a :class:`~.QuantizedModule`. | |||
| """ | |||
| def forward(self, inp): | |||
| return inp | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float.DequantStub): | |||
| r""" | |||
| Return a :class:`~.QATModule` instance converted from | |||
| a float :class:`~.Module` instance. | |||
| """ | |||
| return cls() | |||
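| # Typical placement sketch for the stubs above: quantize at the entry and | |||
| # de-quantize at the exit of the region to be quantized (illustrative net): | |||
| import megengine.module as M | |||
| class TinyNet(M.Module): | |||
|     def __init__(self): | |||
|         super().__init__() | |||
|         self.quant = M.QuantStub() | |||
|         self.linear = M.Linear(8, 4) | |||
|         self.dequant = M.DequantStub() | |||
|     def forward(self, x): | |||
|         return self.dequant(self.linear(self.quant(x))) | |||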
| @@ -5,30 +5,24 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .module import QATModule | |||
| from .module import Module | |||
| class QuantStub(QATModule): | |||
| class QuantStub(Module): | |||
| r""" | |||
| A helper QATModule doing quantize operation on input. | |||
| A helper :class:`~.Module` simply returning the input. Could be replaced with the :class:`~.QATModule` | |||
| version :class:`~.qat.QuantStub` using :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inp): | |||
| return inp | |||
| def forward_qat(self, inp): | |||
| return self.apply_fakequant_with_observer( | |||
| inp, self.act_fake_quant, self.act_observer | |||
| ) | |||
| class DequantStub(QATModule): | |||
| class DequantStub(Module): | |||
| r""" | |||
| A helper QATModule doing de-quantize operation on input. | |||
| A helper :class:`~.Module` simply returning the input. Could be replaced with the :class:`~.QATModule` | |||
| version :class:`~.qat.DequantStub` using :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inp): | |||
| return inp | |||
| def forward_qat(self, inp): | |||
| return inp | |||
| @@ -9,4 +9,5 @@ from .concat import Concat | |||
| from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | |||
| from .elemwise import Elemwise | |||
| from .linear import Linear | |||
| from .module import QuantizedModule | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| @@ -7,17 +7,15 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable | |||
| from ... import _internal as mgb | |||
| from ... import functional as F | |||
| from ... import module as Float | |||
| from ...core.tensor import Tensor | |||
| from ...quantization.utils import register_method_to_class | |||
| from ..module import Module | |||
| from ..qat import concat as QAT | |||
| from .module import QuantizedModule | |||
| class Concat(Module): | |||
| class Concat(QuantizedModule): | |||
| r""" | |||
| A :class:`~.Module` to do quantized concat, inference only. | |||
| A :class:`~.QuantizedModule` to do quantized concat, inference only. | |||
| """ | |||
| def __init__(self, dtype=None): | |||
| @@ -25,16 +23,13 @@ class Concat(Module): | |||
| self.output_dtype = dtype | |||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| new_inps = (x.astype(self.output_dtype) for x in inps) | |||
| return F.concat(new_inps, axis) | |||
| @register_method_to_class(Float.Concat) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | |||
| implemented here to avoid circular import. | |||
| """ | |||
| return Concat(float_module.act_observer.get_dtype()) | |||
| @classmethod | |||
| def from_qat_module(cls, qat_module: QAT.Concat): | |||
| r""" | |||
| Return a :class:`~.QuantizedModule` instance converted from a | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| return cls(qat_module.get_activation_dtype()) | |||
| @@ -5,7 +5,6 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from functools import partial | |||
| from typing import Tuple, Union | |||
| import megengine._internal as mgb | |||
| @@ -13,11 +12,11 @@ import megengine._internal as mgb | |||
| from ... import module as Float | |||
| from ...core import Parameter | |||
| from ...functional import conv_bias_activation | |||
| from ...module import Conv2d | |||
| from ...quantization.utils import register_method_to_class | |||
| from ..qat import conv_bn_relu as QAT | |||
| from .module import QuantizedModule | |||
| class _ConvBnActivation2d(Conv2d): | |||
| class _ConvBnActivation2d(Float.Conv2d, QuantizedModule): | |||
| r"""Applies a 2D convolution over an quantized input tensor, inference only. | |||
| The parameter is same with :class: `~.Conv2d` | |||
| @@ -68,44 +67,41 @@ class _ConvBnActivation2d(Conv2d): | |||
| nonlinear_mode=nonlinear_mode, | |||
| ) | |||
| @classmethod | |||
| def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d): | |||
| r""" | |||
| Return a :class:`~.QuantizedModule` instance converted from a | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| output_dtype = qat_module.get_activation_dtype() | |||
| qconv = cls( | |||
| qat_module.conv.in_channels, | |||
| qat_module.conv.out_channels, | |||
| qat_module.conv.kernel_size, | |||
| qat_module.conv.stride, | |||
| qat_module.conv.padding, | |||
| qat_module.conv.dilation, | |||
| qat_module.conv.groups, | |||
| dtype=output_dtype, | |||
| ) | |||
| w_fold, b_fold = qat_module.fold_weight_bias( | |||
| qat_module.bn.running_mean, qat_module.bn.running_var | |||
| ) | |||
| weight = w_fold.astype(qat_module.get_weight_dtype()) | |||
| qconv.weight = Parameter(weight.numpy()) | |||
| qconv.bias = Parameter(b_fold.numpy()) | |||
| return qconv | |||
| class ConvBn2d(_ConvBnActivation2d): | |||
| r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`.""" | |||
| def forward(self, inp): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY") | |||
| class ConvBnRelu2d(_ConvBnActivation2d): | |||
| r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`.""" | |||
| def forward(self, inp): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return self.calc_conv_quantized(inp, nonlinear_mode="RELU") | |||
| def to_quantized(quantized_class, float_module): | |||
| output_dtype = float_module.act_observer.get_dtype() | |||
| qconv = quantized_class( | |||
| float_module.conv.in_channels, | |||
| float_module.conv.out_channels, | |||
| float_module.conv.kernel_size, | |||
| float_module.conv.stride, | |||
| float_module.conv.padding, | |||
| float_module.conv.dilation, | |||
| float_module.conv.groups, | |||
| dtype=output_dtype, | |||
| ) | |||
| w_fold, b_fold = float_module.fold_weight_bias( | |||
| float_module.bn.running_mean, float_module.bn.running_var | |||
| ) | |||
| weight = w_fold.astype(float_module.weight_observer.get_dtype()) | |||
| qconv.weight = Parameter(weight.numpy()) | |||
| qconv.bias = Parameter(b_fold.numpy()) | |||
| return qconv | |||
| # replace :class:`~.module.QATModule`'s ``to_quantized`` method. | |||
| # implemented here to avoid circular import. | |||
| register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d)) | |||
| register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d)) | |||
| @@ -6,11 +6,10 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ... import _internal as mgb | |||
| from ... import module as Float | |||
| from ...core import Tensor, wrap_io_tensor | |||
| from ...core.graph import _use_default_if_none | |||
| from ...quantization.utils import register_method_to_class | |||
| from ..module import Module | |||
| from ..qat import elemwise as QAT | |||
| from .module import QuantizedModule | |||
| @wrap_io_tensor | |||
| @@ -24,13 +23,8 @@ def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor: | |||
| return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs) | |||
| class Elemwise(Module): | |||
| r""" | |||
| quantized module for elemwise operator, inference only. | |||
| :param method: the elemwise method, supported string refer to :class:`~.module.elemwise.Elemwise`. | |||
| it will do quantized operator with specified output quantized dtype. | |||
| """ | |||
| class Elemwise(QuantizedModule): | |||
| r"""quantized version of :class:`~.qat.elemwise.Elemwise`.""" | |||
| _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode | |||
| @@ -44,11 +38,10 @@ class Elemwise(Module): | |||
| raise ValueError("quantized module only support inference.") | |||
| return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype) | |||
| @register_method_to_class(Float.Elemwise) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | |||
| implemented here to avoid circular import. | |||
| """ | |||
| return Elemwise(float_module.method.name, float_module.act_observer.get_dtype()) | |||
| @classmethod | |||
| def from_qat_module(cls, qat_module: QAT.Elemwise): | |||
| r""" | |||
| Return a :class:`~.QuantizedModule` instance converted from a | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| return cls(qat_module.method.name, qat_module.get_activation_dtype()) | |||
| @@ -10,19 +10,13 @@ import numpy as np | |||
| import megengine._internal as mgb | |||
| from ... import functional as F | |||
| from ... import module as Float | |||
| from ...core import Parameter | |||
| from ...quantization.utils import register_method_to_class | |||
| from ..module import Module | |||
| from ..qat import linear as QAT | |||
| from .module import QuantizedModule | |||
| class Linear(Module): | |||
| r"""Applies a quantized linear transformation to the input. The module | |||
| usually convert from QAT module by to_quantized method. | |||
| :param dtype: output data type. | |||
| """ | |||
| class Linear(QuantizedModule): | |||
| r"""quantized version of :class:`~.qat.linear.Linear`.""" | |||
| def __init__( | |||
| self, dtype: np.dtype = None, | |||
| @@ -44,17 +38,16 @@ class Linear(Module): | |||
| None if self.bias is None else self.bias.astype(bias_dtype), | |||
| ).astype(self.output_dtype) | |||
| @register_method_to_class(Float.Linear) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | |||
| implemented here to avoid circular import. | |||
| """ | |||
| output_dtype = float_module.act_observer.get_dtype() | |||
| qmod = Linear(dtype=output_dtype,) | |||
| weight = float_module.weight.astype(float_module.weight_observer.get_dtype()) | |||
| qmod.weight = Parameter(weight.numpy()) | |||
| if float_module.bias is not None: | |||
| qmod.bias = Parameter(float_module.bias.numpy()) | |||
| return qmod | |||
| @classmethod | |||
| def from_qat_module(cls, qat_module: QAT.Linear): | |||
| r""" | |||
| Return a :class:`~.QuantizedModule` instance converted from a | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| output_dtype = qat_module.get_activation_dtype() | |||
| qmod = cls(dtype=output_dtype) | |||
| weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | |||
| qmod.weight = Parameter(weight.numpy()) | |||
| if qat_module.bias is not None: | |||
| qmod.bias = Parameter(qat_module.bias.numpy()) | |||
| return qmod | |||
| @@ -0,0 +1,31 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import abstractmethod | |||
| from ..module import Module | |||
| from ..qat import QATModule | |||
| class QuantizedModule(Module): | |||
| r""" | |||
| Base class of quantized Modules, which should be converted from QATModule | |||
| and do not support training. | |||
| """ | |||
| def __call__(self, *inputs, **kwargs): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return super().__call__(*inputs, **kwargs) | |||
| @classmethod | |||
| @abstractmethod | |||
| def from_qat_module(cls, qat_module: QATModule): | |||
| r""" | |||
| Return a :class:`~.QuantizedModule` instance converted from a | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| @@ -5,15 +5,14 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ... import _internal as mgb | |||
| from ... import module as Float | |||
| from ...quantization.utils import register_method_to_class | |||
| from ..module import Module | |||
| from ..qat import quant_dequant as QAT | |||
| from .module import QuantizedModule | |||
| class QuantStub(Module): | |||
| class QuantStub(QuantizedModule): | |||
| r""" | |||
| A helper quantize operation on input and inference only. | |||
| quantized version of :class:`~.qat.quant_dequant.QuantStub`, | |||
| which converts the input to a quantized dtype. | |||
| """ | |||
| def __init__(self, dtype=None): | |||
| @@ -21,35 +20,30 @@ class QuantStub(Module): | |||
| self.output_dtype = dtype | |||
| def forward(self, inp): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return inp.astype(self.output_dtype) | |||
| @classmethod | |||
| def from_qat_module(cls, qat_module: QAT.QuantStub): | |||
| r""" | |||
| Return a :class:`~.QuantizedModule` instance converted from a | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| return cls(qat_module.get_activation_dtype()) | |||
| class DequantStub(Module): | |||
| class DequantStub(QuantizedModule): | |||
| r""" | |||
| A helper de-quantize operation and inference only. | |||
| quantized version of :class:`~.qat.quant_dequant.DequantStub`, | |||
| which restores the quantized input to ``float32`` dtype. | |||
| """ | |||
| def forward(self, inp): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return inp.astype("float32") | |||
| @register_method_to_class(Float.QuantStub) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | |||
| implemented here to avoid circular import. | |||
| """ | |||
| return QuantStub(float_module.act_observer.get_dtype()) | |||
| @register_method_to_class(Float.DequantStub) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method. | |||
| implemented here to avoid circular import. | |||
| """ | |||
| return DequantStub() | |||
| @classmethod | |||
| def from_qat_module(cls, qat_module: QAT.DequantStub): | |||
| r""" | |||
| Return a :class:`~.QuantizedModule` instance converted from a | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| return cls() | |||
| @@ -13,12 +13,3 @@ from .qconfig import ( | |||
| ema_fakequant_qconfig, | |||
| min_max_fakequant_qconfig, | |||
| ) | |||
| from .quantize import ( | |||
| disable_fake_quant, | |||
| disable_observer, | |||
| enable_fake_quant, | |||
| enable_observer, | |||
| quantize, | |||
| quantize_calibration, | |||
| quantize_qat, | |||
| ) | |||
| @@ -15,16 +15,12 @@ from .observer import ( | |||
| class QConfig: | |||
| """ | |||
| r""" | |||
| A config class indicating how to quantize a :class:`~.QATModule`'s | |||
| ``activation`` and ``weight``. | |||
| And ``fake_quant`` parameter to indicate | |||
| See :meth:`~.QATModule.set_qconfig` for detail usage. | |||
| ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detailed usage. | |||
| :param weight_observer: interface to instantiate an :class:`~.Observer` indicating | |||
| - how to collect scales and zero_point of wegiht. | |||
| how to collect scales and zero_point of weight. | |||
| :param act_observer: similar to ``weight_observer`` but toward activation. | |||
| :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||
| how to do fake_quant calculation. Can be invoked multiple times to get different | |||
| @@ -6,68 +6,125 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from copy import deepcopy | |||
| from ..module import Module, QATModule, Sequential, quantized | |||
| from typing import Dict, Tuple | |||
| from .. import module as Float | |||
| from ..module import Module | |||
| from ..module import qat as QAT | |||
| from ..module import quantized as Quantized | |||
| from ..module.qat import QATModule | |||
| from ..module.quantized import QuantizedModule | |||
| from .qconfig import QConfig, ema_fakequant_qconfig | |||
| def _get_quantable_module_names(): | |||
| def is_quantable(key: str): | |||
| value = getattr(Quantized, key) | |||
| return ( | |||
| isinstance(value, type) | |||
| and issubclass(value, QuantizedModule) | |||
| and value != QuantizedModule | |||
| ) | |||
| # source should have all quantable modules' names | |||
| quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)] | |||
| return quantable_module_names | |||
| def _get_convert_dict() -> Tuple[ | |||
| Dict[Module, QATModule], Dict[QATModule, QuantizedModule] | |||
| ]: | |||
| quantable_module_names = _get_quantable_module_names() | |||
| quantable_modules = [getattr(Float, key) for key in quantable_module_names] | |||
| qat_modules = [getattr(QAT, key) for key in quantable_module_names] | |||
| quantized_modules = [getattr(Quantized, key) for key in quantable_module_names] | |||
| float2qat_dict = dict(zip(quantable_modules, qat_modules)) | |||
| qat2quantized_dict = dict(zip(qat_modules, quantized_modules)) | |||
| return float2qat_dict, qat2quantized_dict | |||
| _float2qat_dict, _qat2quantized_dict = _get_convert_dict() | |||
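| # Illustration of the mappings built above (assuming Linear is in the quantable | |||
| # set registered by this diff): | |||
| assert _float2qat_dict[Float.Linear] is QAT.Linear | |||
| assert _qat2quantized_dict[QAT.Linear] is Quantized.Linear | |||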
| def quantize(module: Module, inplace=True): | |||
| r""" | |||
| Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`. | |||
| Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` | |||
| through :meth:`~.Module.apply`. | |||
| :param module: root module to do convert recursively. | |||
| :param inplace: whether to convert submodules in-place. | |||
| """ | |||
| if not inplace: | |||
| module = deepcopy(module) | |||
| def is_qat_module(obj): | |||
| return isinstance(obj, QATModule) | |||
| qat_modules = tuple(_qat2quantized_dict.keys()) | |||
| def is_qat(mod: Module): | |||
| return isinstance(mod, qat_modules) | |||
| # no need to pass prefix and get pure key of parent Module. | |||
| for key, submodule, parent in module._flatten( | |||
| with_key=True, with_parent=True, predicate=is_qat_module | |||
| with_key=True, with_parent=True, predicate=is_qat | |||
| ): | |||
| if isinstance(parent, Sequential): | |||
| new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule) | |||
| if isinstance(parent, Float.Sequential): | |||
| # cannot use setattr to be compatible with Sequential's ``__setitem__`` | |||
| parent[int(key.split(".")[-1])] = submodule.to_quantized() | |||
| parent[int(key.split(".")[-1])] = new_mod | |||
| else: | |||
| setattr(parent, key.split(".")[-1], submodule.to_quantized()) | |||
| setattr(parent, key.split(".")[-1], new_mod) | |||
| return module | |||
| def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||
| def quantize_qat( | |||
| module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig | |||
| ): | |||
| r""" | |||
| Recursively convert `module` to `qat` mode through :meth:`~.Module.apply` | |||
| and set qconfig relatively. | |||
| Recursively convert float :class:`~.Module` to :class:`~.QATModule` | |||
| through :meth:`~.Module.apply` and set qconfig relatively. | |||
| :param module: root module to do convert recursively. | |||
| :param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||
| default is :any:`~.qconfig.ema_fakequant_qconfig`. | |||
| :param inplace: whether to convert submodules in-place. | |||
| :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||
| Default: ``ema_fakequant_qconfig``. | |||
| """ | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| mod.set_qat_mode(QATModule.QATMode.QAT) | |||
| mod.set_qconfig(qconfig) | |||
| if not inplace: | |||
| module = deepcopy(module) | |||
| module.apply(fn) | |||
| quantable_modules = tuple(_float2qat_dict.keys()) | |||
| def is_quantable(mod: Module): | |||
| return isinstance(mod, quantable_modules) | |||
| # no need to pass prefix and get pure key of parent Module. | |||
| for key, submodule, parent in module._flatten( | |||
| with_key=True, with_parent=True, predicate=is_quantable | |||
| ): | |||
| new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule) | |||
| if isinstance(parent, Float.Sequential): | |||
| # cannot use setattr to be compatible with Sequential's ``__setitem__`` | |||
| parent[int(key.split(".")[-1])] = new_mod | |||
| else: | |||
| setattr(parent, key.split(".")[-1], new_mod) | |||
| propagate_qconfig(module, qconfig) | |||
| return module | |||
| def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||
| def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||
| r""" | |||
| Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply` | |||
| and set qconfig relatively. | |||
| Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. | |||
| :param module: root module to do convert recursively. | |||
| :param module: root module to traverse recursively. | |||
| :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||
| """ | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| mod.set_qat_mode(QATModule.QATMode.CALIBRATION) | |||
| mod.set_qconfig(qconfig) | |||
| module.apply(fn) | |||
| @@ -5,8 +5,7 @@ import numpy as np | |||
| from megengine import tensor | |||
| from megengine.module import ConvBn2d | |||
| from megengine.quantization import quantize_qat | |||
| from megengine.quantization.quantize import disable_fake_quant | |||
| from megengine.quantization.quantize import disable_fake_quant, quantize_qat | |||
| from megengine.test import assertTensorClose | |||
| @@ -14,18 +13,17 @@ def test_convbn2d(): | |||
| in_channels = 32 | |||
| out_channels = 64 | |||
| kernel_size = 3 | |||
| module = ConvBn2d(in_channels, out_channels, kernel_size) | |||
| quantize_qat(module) | |||
| for groups, bias in product([1, 4], [True, False]): | |||
| inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||
| module = ConvBn2d( | |||
| in_channels, out_channels, kernel_size, groups=groups, bias=bias | |||
| ) | |||
| module.train() | |||
| qat_module = copy.deepcopy(module) | |||
| qat_module = quantize_qat(module, inplace=False) | |||
| disable_fake_quant(qat_module) | |||
| normal_outputs = module.forward(inputs) | |||
| qat_outputs = qat_module.forward_qat(inputs) | |||
| inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||
| normal_outputs = module(inputs) | |||
| qat_outputs = qat_module(inputs) | |||
| assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) | |||
| a = module.bn.running_mean.numpy() | |||
| b = qat_module.bn.running_mean.numpy() | |||
| assertTensorClose( | |||
| module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8 | |||
| ) | |||
| @@ -33,7 +31,7 @@ def test_convbn2d(): | |||
| module.bn.running_var, qat_module.bn.running_var, max_err=5e-7 | |||
| ) | |||
| module.eval() | |||
| normal_outputs = module.forward(inputs) | |||
| normal_outputs = module(inputs) | |||
| qat_module.eval() | |||
| qat_outputs = qat_module.forward_qat(inputs) | |||
| qat_outputs = qat_module(inputs) | |||
| assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) | |||
| @@ -0,0 +1,38 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from megengine import module as Float | |||
| from megengine.module import qat as QAT | |||
| from megengine.quantization.quantize import _get_quantable_module_names | |||
| def test_get_quantable_module_names(): | |||
| # need to make sure names from Quantized and QAT are the same | |||
| def _get_qat_module_names(): | |||
| def is_qat(key: str): | |||
| value = getattr(QAT, key) | |||
| return ( | |||
| isinstance(value, type) | |||
| and issubclass(value, QAT.QATModule) | |||
| and value != QAT.QATModule | |||
| ) | |||
| # source should have all quantable modules' names | |||
| quantable_module_names = [key for key in dir(QAT) if is_qat(key)] | |||
| return quantable_module_names | |||
| qat_module_names = _get_qat_module_names() | |||
| quantized_module_names = _get_quantable_module_names() | |||
| assert set(qat_module_names) == set(quantized_module_names) | |||
| for key in qat_module_names: | |||
| value = getattr(Float, key) | |||
| assert ( | |||
| isinstance(value, type) | |||
| and issubclass(value, Float.Module) | |||
| and value != Float.Module | |||
| ) | |||