GitOrigin-RevId: 4fe1233ec3
tags/v0.4.0
| @@ -74,12 +74,14 @@ from .nn import ( | |||
| softmax, | |||
| warp_perspective, | |||
| ) | |||
| from .quantized import conv_bias_activation | |||
| from .sort import argsort, sort, top_k | |||
| from .tensor import ( | |||
| add_axis, | |||
| arange, | |||
| broadcast_to, | |||
| concat, | |||
| cond_take, | |||
| dimshuffle, | |||
| gather, | |||
| linspace, | |||
| @@ -0,0 +1,84 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # pylint: disable=too-many-lines | |||
| from typing import Tuple, Union | |||
| from .. import _internal as mgb | |||
| from ..core import Tensor, wrap_io_tensor | |||
| from ..utils.types import _pair, _pair_nonzero | |||
| from .debug_param import get_conv_execution_strategy | |||
| @wrap_io_tensor | |||
| def conv_bias_activation( | |||
| inp: Tensor, | |||
| weight: Tensor, | |||
| bias: Tensor, | |||
| dtype=None, | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| nonlinear_mode="IDENTITY", | |||
| conv_mode="CROSS_CORRELATION", | |||
| compute_mode="DEFAULT", | |||
| ) -> Tensor: | |||
| """ convolution bias with activation operation, only for inference. | |||
| :param inp: The feature map of the convolution operation | |||
| :param weight: The convolution kernel | |||
| :param bias: The bias added to the result of convolution | |||
| :param stride: Stride of the 2D convolution operation. Default: 1 | |||
| :param padding: Size of the paddings added to the input on both sides of its | |||
| spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param dilation: Dilation of the 2D convolution operation. Default: 1 | |||
| :param groups: number of groups to divide input and output channels into, | |||
| so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
| and the shape of weight should be ``(groups, out_channel // groups, | |||
| in_channels // groups, height, width)``. | |||
| :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode` | |||
| :param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
| 'CROSS_CORRELATION'. | |||
| :param dtype: data type of the output. Default: np.int8. | |||
| :param scale: scale used for quantization. Default: 0.0. | |||
| :param zero_point: zero point used for quint8 quantization. Default: 0.0. | |||
| :type compute_mode: string or | |||
| :class:`mgb.opr_param_defs.Convolution.ComputeMode` | |||
| :param compute_mode: When set to 'DEFAULT', no special requirements will be | |||
| placed on the precision of intermediate results. When set to 'FLOAT32', | |||
| Float32 would be used for accumulator and intermediate result, but only | |||
| effective when input and output are of Float16 dtype. | |||
| """ | |||
| ph, pw = _pair(padding) | |||
| sh, sw = _pair_nonzero(stride) | |||
| dh, dw = _pair_nonzero(dilation) | |||
| sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
| res = mgb.opr.conv_bias_activation( | |||
| inp, | |||
| weight, | |||
| bias, | |||
| compute_mode=compute_mode, | |||
| dtype=dtype, | |||
| strategy=get_conv_execution_strategy(), | |||
| nonlineMode=nonlinear_mode, | |||
| sparse=sparse_type, | |||
| format="NCHW", | |||
| pad_h=ph, | |||
| pad_w=pw, | |||
| stride_h=sh, | |||
| stride_w=sw, | |||
| dilate_h=dh, | |||
| dilate_w=dw, | |||
| mode=conv_mode, | |||
| ) | |||
| return res | |||
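For reference, a minimal usage sketch of ``conv_bias_activation`` (scales and shapes are illustrative; the skipped inference test near the end of this diff shows the full setup):

.. code-block::

    import numpy as np
    import megengine._internal as mgb
    import megengine.functional as F
    from megengine import Parameter, tensor

    inp_dtype = mgb.dtype.qint8(0.01)  # illustrative input scale
    w_dtype = mgb.dtype.qint8(0.02)    # illustrative weight scale
    b_dtype = mgb.dtype.qint32(0.01 * 0.02)
    out_dtype = mgb.dtype.qint8(0.1)   # illustrative output scale

    inp = tensor(mgb.dtype.convert_to_qint8(np.random.normal(size=(1, 4, 24, 24)), inp_dtype), dtype=inp_dtype)
    w = Parameter(mgb.dtype.convert_to_qint8(np.random.normal(size=(8, 4, 1, 1)), w_dtype), dtype=w_dtype)
    b = Parameter(mgb.dtype.convert_to_qint32(np.zeros((1, 8, 1, 1)), b_dtype), dtype=b_dtype)

    out = F.conv_bias_activation(inp, w, b, dtype=out_dtype, nonlinear_mode="RELU")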
| @@ -359,6 +359,41 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: | |||
| return out | |||
| @wrap_io_tensor | |||
| def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor: | |||
| r""" | |||
| Take elements from data if the specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; both are 1-dimensional. Higher-dimensional input is flattened first. | |||
| :param mask: condition param; must have the same shape as data | |||
| :param x: input tensor from which to take elements | |||
| :param val: value to be compared to by mode | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32)) | |||
| x = tensor(np.array([[1, np.inf], [np.nan, 4]], | |||
| dtype=np.float32)) | |||
| v, index = F.cond_take(mask, x, 1) | |||
| print(v, index) | |||
| Outputs: | |||
| .. testoutput:: | |||
| Tensor([1. 4.]) Tensor([0 3], dtype=int32) | |||
| """ | |||
| v, index = mgb.opr.cond_take( | |||
| x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val | |||
| ) | |||
| return v, index | |||
| def shapeof(x: Tensor, axis=None): | |||
| r""" | |||
| The shape of input tensor. | |||
| @@ -8,12 +8,16 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
| from .batchnorm import BatchNorm1d, BatchNorm2d | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvTranspose2d | |||
| from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | |||
| from .dropout import Dropout | |||
| from .elemwise import Elemwise | |||
| from .embedding import Embedding | |||
| from .identity import Identity | |||
| from .linear import Linear | |||
| from .module import Module | |||
| from .module import Module, QATModule | |||
| from .parampack import ParamPack | |||
| from .pooling import AvgPool2d, MaxPool2d | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| from .sequential import Sequential | |||
| @@ -0,0 +1,27 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable | |||
| from .. import functional as F | |||
| from ..core.tensor import Tensor | |||
| from .module import QATModule | |||
| class Concat(QATModule): | |||
| r""" | |||
| A :class:`~.QATModule` that wraps functional concat; plain concat should be replaced with this module, | |||
| supporting ``qat`` mode and ``quantized`` mode. | |||
| """ | |||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
| return F.concat(inps, axis) | |||
| def forward_qat(self, inps: Iterable[Tensor], axis: int = 0): | |||
| return self.apply_fakequant_with_observer( | |||
| self.forward(inps, axis), self.act_fake_quant, self.act_observer | |||
| ) | |||
| @@ -182,11 +182,11 @@ class Conv2d(_ConvNd): | |||
| # Assume format is NCHW | |||
| return (1, self.out_channels, 1, 1) | |||
| def forward(self, inp): | |||
| def calc_conv(self, inp, weight, bias): | |||
| return conv2d( | |||
| inp, | |||
| self.weight, | |||
| self.bias, | |||
| weight, | |||
| bias, | |||
| self.stride, | |||
| self.padding, | |||
| self.dilation, | |||
| @@ -195,6 +195,9 @@ class Conv2d(_ConvNd): | |||
| self.compute_mode, | |||
| ) | |||
| def forward(self, inp): | |||
| return self.calc_conv(inp, self.weight, self.bias) | |||
| class ConvTranspose2d(_ConvNd): | |||
| r"""Applies a 2D transposed convolution over an input tensor. | |||
| @@ -0,0 +1,168 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Tuple, Union | |||
| from ..core import ones, zeros | |||
| from ..functional import flatten, relu, sqrt, sum | |||
| from .batchnorm import BatchNorm2d | |||
| from .conv import Conv2d | |||
| from .module import QATModule | |||
| class _ConvBn2d(QATModule): | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| bias: bool = True, | |||
| conv_mode: str = "CROSS_CORRELATION", | |||
| compute_mode: str = "DEFAULT", | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| track_running_stats=True, | |||
| freeze_bn=False, | |||
| ): | |||
| super().__init__() | |||
| self.conv = Conv2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| conv_mode, | |||
| compute_mode, | |||
| ) | |||
| self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) | |||
| self.freeze_bn = freeze_bn | |||
| def update_bn_stats(self): | |||
| self.freeze_bn = False | |||
| return self | |||
| def freeze_bn_stats(self): | |||
| self.freeze_bn = True | |||
| return self | |||
| def get_bn_gamma_beta(self): | |||
| if self.bn.weight is None: | |||
| gamma = ones((self.bn.num_features), dtype="float32") | |||
| else: | |||
| gamma = self.bn.weight | |||
| if self.bn.bias is None: | |||
| beta = zeros((self.bn.num_features), dtype="float32") | |||
| else: | |||
| beta = self.bn.bias | |||
| return gamma, beta | |||
| def get_batch_mean_var(self, inp): | |||
| def _sum_channel(inp, axis=0, keepdims=True): | |||
| if isinstance(axis, int): | |||
| out = sum(inp, axis=axis, keepdims=keepdims) | |||
| elif isinstance(axis, tuple): | |||
| for idx, elem in enumerate(axis): | |||
| out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) | |||
| return out | |||
| sum1 = _sum_channel(inp, (0, 2, 3)) | |||
| sum2 = _sum_channel(inp ** 2, (0, 2, 3)) | |||
| reduce_size = inp.shapeof().prod() / inp.shapeof(1) | |||
| batch_mean = sum1 / reduce_size | |||
| batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1) | |||
| return batch_mean, batch_var | |||
| def fold_weight_bias(self, bn_mean, bn_var): | |||
| # get fold bn conv param | |||
| # bn_istd = 1 / bn_std | |||
| # w_fold = gamma / bn_std * W | |||
| # b_fold = gamma * (b - bn_mean) / bn_std + beta | |||
| gamma, beta = self.get_bn_gamma_beta() | |||
| b = self.conv.bias | |||
| if b is None: | |||
| b = zeros(self.conv._infer_bias_shape(), dtype="float32") | |||
| if bn_mean is None: | |||
| bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| if bn_var is None: | |||
| bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| if self.conv.groups == 1: | |||
| w_fold = ( | |||
| self.conv.weight | |||
| * gamma.reshape(-1, 1, 1, 1) | |||
| * bn_istd.reshape(-1, 1, 1, 1) | |||
| ) | |||
| else: | |||
| w_fold = ( | |||
| self.conv.weight | |||
| * gamma.reshape(self.conv.groups, -1, 1, 1, 1) | |||
| * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1) | |||
| ) | |||
| b_fold = flatten(beta) + ( | |||
| flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd) | |||
| ) | |||
| b_fold = b_fold.reshape(self.conv._infer_bias_shape()) | |||
| return w_fold, b_fold | |||
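The folding formulas in the comment above can be sanity-checked with plain numpy. A sketch (all names are local to this example) that relies on the convolution being linear per output channel:

.. code-block::

    import numpy as np

    rng = np.random.default_rng(0)
    oc, eps = 8, 1e-5
    gamma = rng.normal(size=oc)           # BN weight
    beta = rng.normal(size=oc)            # BN bias
    mean = rng.normal(size=oc)            # BN running mean
    var = rng.uniform(0.5, 2.0, size=oc)  # BN running variance
    b = rng.normal(size=oc)               # conv bias, one value per channel
    y = rng.normal(size=oc)               # conv output (without bias), per channel

    istd = 1.0 / np.sqrt(var + eps)
    b_fold = beta + gamma * (b - mean) * istd

    # bn(conv(x) + b) is affine per channel, so scaling the weights by
    # gamma * istd (w_fold) and using b_fold reproduces it exactly:
    lhs = gamma * ((y + b) - mean) * istd + beta  # BN applied to conv + bias
    rhs = y * (gamma * istd) + b_fold             # folded conv + folded bias
    assert np.allclose(lhs, rhs)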
| def calc_conv_bn_qat(self, inp): | |||
| # TODO: use pytorch method as | |||
| conv = self.conv(inp) | |||
| self.bn(conv) | |||
| if self.training: | |||
| bn_mean, bn_var = self.get_batch_mean_var(conv) | |||
| else: | |||
| bn_mean, bn_var = self.bn.running_mean, self.bn.running_var | |||
| w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var) | |||
| w_qat = self.apply_fakequant_with_observer( | |||
| w_fold, self.weight_fake_quant, self.weight_observer | |||
| ) | |||
| return self.conv.calc_conv(inp, w_qat, b_fold) | |||
| class ConvBn2d(_ConvBn2d): | |||
| r""" | |||
| A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode | |||
| and ``normal`` mode. | |||
| """ | |||
| def forward_qat(self, inp): | |||
| return self.apply_fakequant_with_observer( | |||
| self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer | |||
| ) | |||
| def forward(self, inp): | |||
| return self.bn(self.conv(inp)) | |||
| class ConvBnRelu2d(_ConvBn2d): | |||
| r""" | |||
| A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat`` | |||
| mode and ``normal`` mode. | |||
| """ | |||
| def forward_qat(self, inp): | |||
| return self.apply_fakequant_with_observer( | |||
| relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer | |||
| ) | |||
| def forward(self, inp): | |||
| return relu(self.bn(self.conv(inp))) | |||
| @@ -0,0 +1,95 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .. import _internal as mgb | |||
| from ..core import Tensor, wrap_io_tensor | |||
| from ..core.graph import _use_default_if_none | |||
| from .module import QATModule | |||
| @wrap_io_tensor | |||
| def _elemwise_func(mode, *inputs, **kwargs) -> Tensor: | |||
| if all(isinstance(i, (int, float)) for i in inputs): | |||
| device, comp_graph = _use_default_if_none(None, None) | |||
| ret = mgb.opr.elemwise( | |||
| *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs | |||
| ) | |||
| return ret.inferred_value[0] | |||
| return mgb.opr.elemwise(*inputs, mode=mode, **kwargs) | |||
| class Elemwise(QATModule): | |||
| r""" | |||
| A :class:`~.QATModule` to do elemwise operator; the functional operator should be replaced with this module, | |||
| supporting ``qat`` mode and ``normal`` mode. | |||
| :param method: the elemwise method, supporting the following strings. | |||
| It will do the normal elemwise operator for float. | |||
| * "ADD": a + b | |||
| * "FUSE_ADD_RELU": max(x+y, 0) | |||
| * "MUL": x * y | |||
| * "MIN": min(x, y) | |||
| * "MAX": max(x, y) | |||
| * "SUB": x - y | |||
| * "TRUE_DIV": x / y | |||
| * "FUSE_ADD_SIGMOID": sigmoid(x + y) | |||
| * "FUSE_ADD_TANH": tanh(x + y) | |||
| * "RELU": x > 0 ? x : 0 | |||
| * "ABS": x > 0 ? x : -x | |||
| * "SIGMOID": sigmoid(x) | |||
| * "EXP": exp(x) | |||
| * "TANH": tanh(x) | |||
| * "FUSE_MUL_ADD3": x * y + z | |||
| * "FAST_TANH": fast_tanh(x) | |||
| * "NEGATE": -x | |||
| * "ACOS": acos(x) | |||
| * "ASIN": asin(x) | |||
| * "CEIL": ceil(x) | |||
| * "COS": cos(x) | |||
| * "EXPM1": expm1(x) | |||
| * "FLOOR": floor(x) | |||
| * "LOG": log(x) | |||
| * "LOG1P": log1p(x) | |||
| * "SIN": sin(x) | |||
| * "ROUND": round(x) | |||
| * "ERF": erf(x) | |||
| * "ERFINV": erfinv(x) | |||
| * "ERFC": erfc(x) | |||
| * "ERFCINV": erfcinv(x) | |||
| * "ABS_GRAD": abs_grad | |||
| * "FLOOR_DIV": floor_div | |||
| * "MOD": mod | |||
| * "SIGMOID_GRAD": sigmoid_grad | |||
| * "SWITCH_GT0": switch_gt0 | |||
| * "TANH_GRAD": tanh_grad | |||
| * "LT": lt | |||
| * "LEQ": leq | |||
| * "EQ": eq | |||
| * "POW": pow | |||
| * "LOG_SUM_EXP": log_sum_exp | |||
| * "FAST_TANH_GRAD": fast_tanh_grad | |||
| * "ATAN2": atan2 | |||
| * "COND_LEQ_MOV": cond_leq_mov | |||
| * "H_SWISH": h_swish | |||
| * "FUSE_ADD_H_SWISH": h_swish(x+y) | |||
| * "H_SWISH_GRAD": h_swish_grad | |||
| """ | |||
| _elemwise_mode_type = mgb.opr_param_defs.Elemwise.Mode | |||
| def __init__(self, method): | |||
| super().__init__() | |||
| self.method = self._elemwise_mode_type.convert(method) | |||
| def forward(self, *inps): | |||
| return _elemwise_func(self.method, *inps) | |||
| def forward_qat(self, *inps): | |||
| return self.apply_fakequant_with_observer( | |||
| self.forward(*inps), self.act_fake_quant, self.act_observer, | |||
| ) | |||
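A quick usage sketch for the module above in float (``normal``) mode; values are illustrative:

.. code-block::

    import numpy as np
    from megengine import tensor
    from megengine.module import Elemwise

    add_relu = Elemwise("FUSE_ADD_RELU")  # computes max(x + y, 0)
    x = tensor(np.array([-1.0, 2.0], dtype=np.float32))
    y = tensor(np.array([2.0, 1.0], dtype=np.float32))
    out = add_relu(x, y)  # -> [1., 3.]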
| @@ -1,4 +1,3 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| @@ -8,6 +7,7 @@ | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import ABCMeta, abstractmethod | |||
| from collections import OrderedDict | |||
| from enum import Enum | |||
| from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
| import numpy as np | |||
| @@ -442,3 +442,95 @@ class Module(metaclass=ABCMeta): | |||
| loaded.append(k) | |||
| return set(loaded), set(skipped) | |||
| class QATModule(Module): | |||
| r""" | |||
| Base class of quantization related Module. Add extra forward methods | |||
| :meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for | |||
| ``qat`` (quantization-aware training) mode and ``quantized`` mode respectively. | |||
| Use :meth:`~.QATModule.set_qat_mode` to switch between ``QAT`` and ``DISABLED`` mode, | |||
| and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode, | |||
| which is irreversible. | |||
| If you want to recursively switch mode for all QATModule in network, use | |||
| functions in :mod:`~.quantization.quantize`. | |||
| """ | |||
| class QATMode(Enum): | |||
| DISABLED = 1 | |||
| QAT = 2 | |||
| CALIBRATION = 3 | |||
| def __init__(self): | |||
| from ..quantization import ( | |||
| QConfig, | |||
| FakeQuantize, | |||
| Observer, | |||
| ) # pylint: disable=all | |||
| super().__init__() | |||
| self.quantizing = self.QATMode.DISABLED | |||
| self.scale = None | |||
| self.inp_observer = None # type: Observer | |||
| self.weight_observer = None # type: Observer | |||
| self.act_observer = None # type: Observer | |||
| self.weight_fake_quant = None # type: FakeQuantize | |||
| self.bias_fake_quant = None # type: FakeQuantize | |||
| self.act_fake_quant = None # type: FakeQuantize | |||
| def set_qconfig(self, qconfig: "QConfig"): | |||
| self.inp_observer = qconfig.inp_observer() | |||
| self.weight_observer = qconfig.weight_observer() | |||
| self.act_observer = qconfig.act_observer() | |||
| self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) | |||
| self.bias_fake_quant = qconfig.bias_fake_quant() | |||
| self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) | |||
| def apply_observer(self, target: Tensor, obs: "Observer"): | |||
| return obs(target) | |||
| def apply_fakequant_with_observer( | |||
| self, target: Tensor, fq: "FakeQuantize", obs: "Observer" | |||
| ): | |||
| oup = self.apply_observer(target, obs) | |||
| return fq(oup, obs.scale, obs.zero_point) | |||
| def set_qat_mode(self, mode: QATMode): | |||
| r""" | |||
| Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``, | |||
| ``QAT``, ``CALIBRATION``. | |||
| """ | |||
| if not isinstance(mode, self.QATMode): | |||
| raise TypeError("mode must be QATMode Enum type") | |||
| self.quantizing = mode | |||
| def to_quantized(self): | |||
| r""" | |||
| Return a new :class:`~.Module` with quantized parameters of ``self`` | |||
| according to scale and zero_point in ``self.xxx_observer``. | |||
| """ | |||
| raise NotImplementedError( | |||
| "Use megengine.quantization.quantize to register the method." | |||
| ) | |||
| @abstractmethod | |||
| def forward_qat(self, *args, **kwargs): | |||
| r""" | |||
| Forward method for ``qat`` mode. | |||
| """ | |||
| def __call__(self, *args, **kwargs): | |||
| if self.quantizing == self.QATMode.QAT: | |||
| return self.forward_qat(*args, **kwargs) | |||
| elif self.quantizing == self.QATMode.CALIBRATION: | |||
| # TODO implement the CALIBRATION | |||
| assert False | |||
| return None | |||
| else: | |||
| return self.forward(*args, **kwargs) | |||
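The ``__call__`` dispatch above means the active mode decides which forward runs. A minimal sketch (``DoubleQAT`` is a hypothetical subclass):

.. code-block::

    from megengine.module import QATModule

    class DoubleQAT(QATModule):
        def forward(self, x):
            return x * 2

        def forward_qat(self, x):
            # a real module would apply fake-quant to the result here
            return self.forward(x)

    m = DoubleQAT()  # starts in QATMode.DISABLED, so calls route to forward
    m.set_qat_mode(QATModule.QATMode.QAT)  # now calls route to forward_qat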
| @@ -0,0 +1,34 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .module import QATModule | |||
| class QuantStub(QATModule): | |||
| r""" | |||
| A helper QATModule doing quantize operation on input. | |||
| """ | |||
| def forward(self, inp): | |||
| return inp | |||
| def forward_qat(self, inp): | |||
| return self.apply_fakequant_with_observer( | |||
| inp, self.act_fake_quant, self.act_observer | |||
| ) | |||
| class DequantStub(QATModule): | |||
| r""" | |||
| A helper QATModule doing de-quantize operation on input. | |||
| """ | |||
| def forward(self, inp): | |||
| return inp | |||
| def forward_qat(self, inp): | |||
| return inp | |||
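Typical placement quantizes at the network entry and de-quantizes at the exit. A sketch (``Net`` is hypothetical):

.. code-block::

    from megengine.module import Conv2d, DequantStub, Module, QuantStub

    class Net(Module):
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()      # fake-quants input in qat mode
            self.conv = Conv2d(3, 8, 3)
            self.dequant = DequantStub()  # back to float at the exit

        def forward(self, x):
            return self.dequant(self.conv(self.quant(x)))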
| @@ -0,0 +1,11 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .concat import Concat | |||
| from .conv_bn_relu import ConvBn2d, ConvBnRelu2d | |||
| from .elemwise import Elemwise | |||
| from .quant_dequant import DequantStub, QuantStub | |||
| @@ -0,0 +1,45 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Iterable | |||
| from ... import _internal as mgb | |||
| from ... import functional as F | |||
| from ... import module as Float | |||
| from ...core.tensor import Tensor | |||
| from ...quantization.utils import register_method_to_class | |||
| from ..module import Module | |||
| class Concat(Module): | |||
| r""" | |||
| A :class:`~.Module` to do quantized concat, inference only. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.scale = 1.0 | |||
| self.zero_point = 0.0 | |||
| self.output_dtype = mgb.dtype.qint8(self.scale) | |||
| def forward(self, inps: Iterable[Tensor], axis: int = 0): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| new_inps = (x.astype(self.output_dtype) for x in inps) | |||
| return F.concat(new_inps, axis) | |||
| @register_method_to_class(Float.Concat) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method, | |||
| implemented here to avoid circular import. | |||
| """ | |||
| qmod = Concat() | |||
| qmod.output_dtype = float_module.act_observer.get_dtype() | |||
| qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||
| return qmod | |||
| @@ -0,0 +1,114 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from functools import partial | |||
| from typing import Tuple, Union | |||
| import megengine._internal as mgb | |||
| from ... import module as Float | |||
| from ...core import Parameter | |||
| from ...functional import conv_bias_activation | |||
| from ...module import Conv2d | |||
| from ...quantization.utils import register_method_to_class | |||
| class _ConvBnActivation2d(Conv2d): | |||
| r"""Applies a 2D convolution over an quantized input tensor, inference only. | |||
| The parameter is same with :class: `~.Conv2d` | |||
| """ | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| conv_mode: str = "CROSS_CORRELATION", | |||
| compute_mode: str = "DEFAULT", | |||
| ): | |||
| super().__init__( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| dilation, | |||
| groups, | |||
| True, | |||
| conv_mode, | |||
| compute_mode, | |||
| ) | |||
| self.scale = 1.0 | |||
| self.zero_point = 0.0 | |||
| self.output_dtype = mgb.dtype.qint8(self.scale) | |||
| self.weight = self.weight.astype(self.output_dtype) | |||
| self.bias = self.bias.astype(mgb.dtype.qint32(self.scale)) | |||
| def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"): | |||
| inp_scale = mgb.dtype.get_scale(inp.dtype) | |||
| w_scale = mgb.dtype.get_scale(self.weight.dtype) | |||
| bias_scale = inp_scale * w_scale | |||
| return conv_bias_activation( | |||
| inp, | |||
| self.weight, | |||
| self.bias.astype(mgb.dtype.qint32(bias_scale)), | |||
| self.output_dtype, | |||
| self.stride, | |||
| self.padding, | |||
| self.dilation, | |||
| self.groups, | |||
| conv_mode=self.conv_mode, | |||
| compute_mode=self.compute_mode, | |||
| nonlinear_mode=nonlinear_mode, | |||
| ) | |||
| class ConvBn2d(_ConvBnActivation2d): | |||
| def forward(self, inp): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY") | |||
| class ConvBnRelu2d(_ConvBnActivation2d): | |||
| def forward(self, inp): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return self.calc_conv_quantized(inp, nonlinear_mode="RELU") | |||
| def to_quantized(quantized_class, float_module): | |||
| qconv = quantized_class( | |||
| float_module.conv.in_channels, | |||
| float_module.conv.out_channels, | |||
| float_module.conv.kernel_size, | |||
| float_module.conv.stride, | |||
| float_module.conv.padding, | |||
| float_module.conv.dilation, | |||
| float_module.conv.groups, | |||
| ) | |||
| w_fold, b_fold = float_module.fold_weight_bias( | |||
| float_module.bn.running_mean, float_module.bn.running_var | |||
| ) | |||
| weight = w_fold.astype(float_module.weight_observer.get_dtype()) | |||
| qconv.output_dtype = float_module.act_observer.get_dtype() | |||
| qconv.weight = Parameter(weight.numpy()) | |||
| qconv.bias = Parameter(b_fold.numpy()) | |||
| qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams() | |||
| return qconv | |||
| # replace :class:`~.module.QATModule`'s ``to_quantized`` method. | |||
| # implemented here to avoid circular import. | |||
| register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d)) | |||
| register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d)) | |||
| @@ -0,0 +1,59 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ... import _internal as mgb | |||
| from ... import module as Float | |||
| from ...core import Tensor, wrap_io_tensor | |||
| from ...core.graph import _use_default_if_none | |||
| from ...quantization.utils import register_method_to_class | |||
| from ..module import Module | |||
| @wrap_io_tensor | |||
| def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor: | |||
| if all(isinstance(i, (int, float)) for i in inputs): | |||
| device, comp_graph = _use_default_if_none(None, None) | |||
| ret = mgb.opr.elemwise_multi_type( | |||
| *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs, | |||
| ) | |||
| return ret.inferred_value[0] | |||
| return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs) | |||
| class Elemwise(Module): | |||
| r""" | |||
| Quantized module for elemwise operator, inference only. | |||
| :param method: the elemwise method; for supported strings refer to :class:`~.module.elemwise.Elemwise`. | |||
| It will do the quantized operator with the specified output quantized dtype. | |||
| """ | |||
| _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode | |||
| def __init__(self, method): | |||
| super().__init__() | |||
| self.method = self._elemwise_multi_type_mode.convert("Q" + method) | |||
| self.scale = 1.0 | |||
| self.zero_point = 0.0 | |||
| self.output_dtype = mgb.dtype.qint8(self.scale) | |||
| def forward(self, *inps): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype) | |||
| @register_method_to_class(Float.Elemwise) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method, | |||
| implemented here to avoid circular import. | |||
| """ | |||
| qmod = Elemwise(float_module.method.name) | |||
| qmod.output_dtype = float_module.act_observer.get_dtype() | |||
| qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||
| return qmod | |||
| @@ -0,0 +1,61 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ... import _internal as mgb | |||
| from ... import module as Float | |||
| from ...quantization.utils import register_method_to_class | |||
| from ..module import Module | |||
| class QuantStub(Module): | |||
| r""" | |||
| A helper :class:`~.Module` doing quantize operation on input, inference only. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.scale = 1.0 | |||
| self.zero_point = 0.0 | |||
| self.output_dtype = mgb.dtype.qint8(self.scale) | |||
| def forward(self, inp): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return inp.astype(self.output_dtype) | |||
| class DequantStub(Module): | |||
| r""" | |||
| A helper :class:`~.Module` doing de-quantize operation on input, inference only. | |||
| """ | |||
| def forward(self, inp): | |||
| if self.training: | |||
| raise ValueError("quantized module only support inference.") | |||
| return inp.astype("float32") | |||
| @register_method_to_class(Float.QuantStub) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method, | |||
| implemented here to avoid circular import. | |||
| """ | |||
| qmod = QuantStub() | |||
| qmod.output_dtype = float_module.act_observer.get_dtype() | |||
| qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams() | |||
| return qmod | |||
| @register_method_to_class(Float.DequantStub) | |||
| def to_quantized(float_module): | |||
| r""" | |||
| Replace :class:`~.module.QATModule`'s ``to_quantized`` method, | |||
| implemented here to avoid circular import. | |||
| """ | |||
| qmod = DequantStub() | |||
| return qmod | |||
| @@ -68,6 +68,7 @@ class Sequential(Module): | |||
| def __setitem__(self, idx, module): | |||
| key = self.layer_keys[idx] | |||
| self.layer_values[idx] = module | |||
| return setattr(self, key, module) | |||
| def __delitem__(self, idx): | |||
| @@ -0,0 +1,11 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .fake_quant import FakeQuantize | |||
| from .observer import Observer | |||
| from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig | |||
| from .quantize import quantize, quantize_qat | |||
| @@ -0,0 +1,48 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .. import functional as F | |||
| from .._internal.dtype import _metadata_dict | |||
| from ..module import Module | |||
| from .observer import Round | |||
| class FakeQuantize(Module): | |||
| r""" | |||
| A module to do quant and dequant according to observer's scale and zero_point. | |||
| """ | |||
| def __init__(self, dtype: str, enable: bool = True): | |||
| super().__init__() | |||
| if dtype not in _metadata_dict.keys(): | |||
| raise ValueError( | |||
| "unknown dtype: {}, only support {}".format( | |||
| dtype, _metadata_dict.keys() | |||
| ) | |||
| ) | |||
| self.dtype = dtype | |||
| self.qmin = _metadata_dict[dtype].qmin | |||
| self.qmax = _metadata_dict[dtype].qmax | |||
| self.enabled = enable | |||
| def enable(self): | |||
| self.enabled = True | |||
| def disable(self): | |||
| self.enabled = False | |||
| def forward(self, inp, scale, zero_point): | |||
| if self.enabled: | |||
| # Quant | |||
| oup = Round()(inp / scale) + zero_point | |||
| # clip | |||
| oup = F.minimum(F.maximum(oup, self.qmin), self.qmax) | |||
| # DeQuant | |||
| oup = (oup - zero_point) * scale | |||
| return oup | |||
| return inp | |||
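The forward pass above is the usual simulated-quantization round trip. In numpy terms, with illustrative numbers:

.. code-block::

    import numpy as np

    scale, zero_point, qmin, qmax = 0.1, 0.0, -128, 127
    x = np.array([-13.0, 0.04, 5.27], dtype=np.float32)

    q = np.round(x / scale) + zero_point  # quantize
    q = np.clip(q, qmin, qmax)            # clip to the dtype range
    x_fq = (q - zero_point) * scale       # de-quantize
    # x_fq -> [-12.8, 0.0, 5.3]: values are snapped onto the qint8 grid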
| @@ -0,0 +1,193 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from abc import abstractmethod | |||
| import numpy as np | |||
| from .. import functional as F | |||
| from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||
| from ..core import Buffer, Function, ones, tensor, zeros | |||
| from ..module import Module | |||
| class Round(Function): | |||
| def forward(self, x): | |||
| return x.round() | |||
| def backward(self, output_grads): | |||
| return output_grads | |||
| class Observer(Module): | |||
| r""" | |||
| A base class for Observer Module. | |||
| :param dtype: a string indicating which dtype to collect scale and zero_point for | |||
| """ | |||
| def __init__(self, dtype="qint8"): | |||
| super().__init__() | |||
| if dtype not in _metadata_dict.keys(): | |||
| raise ValueError( | |||
| "unknown dtype: {}, only support {}".format( | |||
| dtype, _metadata_dict.keys() | |||
| ) | |||
| ) | |||
| self.dtype = dtype | |||
| self.qmin = _metadata_dict[dtype].qmin | |||
| self.qmax = _metadata_dict[dtype].qmax | |||
| self.zero_point, self.scale = None, None | |||
| self.enabled = True | |||
| def get_dtype(self): | |||
| scale, zero_point = self.get_qparams() | |||
| numpy_scale = None if scale is None else scale.numpy()[0] | |||
| numpy_zero_point = None if zero_point is None else zero_point.numpy()[0] | |||
| return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point) | |||
| def enable(self): | |||
| self.enabled = True | |||
| def disable(self): | |||
| self.enabled = False | |||
| @abstractmethod | |||
| def forward(self, x): | |||
| pass | |||
| @abstractmethod | |||
| def get_qparams(self, **kwargs): | |||
| pass | |||
| class IdentityObserver(Observer): | |||
| r""" | |||
| A test Observer that always returns scale 1 and zero_point 0. | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.zero_point = zeros((1), dtype="float32") | |||
| self.scale = ones((1), dtype="float32") | |||
| def forward(self, x): | |||
| return x | |||
| def get_qparams(self): | |||
| return self.scale, self.zero_point | |||
| class MinMaxObserver(Observer): | |||
| def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.symmetric = symmetric | |||
| if self.symmetric: | |||
| # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1' | |||
| self.zero_point = tensor((self.qmin + self.qmax + 1) // 2) | |||
| self.min_val = Buffer(0.0, dtype=np.float32) | |||
| self.max_val = Buffer(0.0, dtype=np.float32) | |||
| self.scale_limit = eps | |||
| # flag is used by cond_take: the first time it is first_flag, afterwards it is set to not_flag | |||
| self.first_flag = Buffer(np.array([1, 0], dtype=np.int32)) | |||
| self.not_flag = Buffer(np.array([0, 1], dtype=np.int32)) | |||
| def set_min_max(self, tmp_min, tmp_max): | |||
| # FIXME: cond_take will destroy the shape, use reshape to reset it | |||
| tmp_min = tmp_min.reshape(1) | |||
| tmp_max = tmp_max.reshape(1) | |||
| if self.training: | |||
| F.zero_grad( | |||
| F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) | |||
| ) | |||
| F.zero_grad( | |||
| F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) | |||
| ) | |||
| F.zero_grad( | |||
| F.add_update( | |||
| self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0 | |||
| ) | |||
| ) | |||
| # FIXME: add_update is applied after the whole trace procedure in `symbolic=True` | |||
| # mode. So use tmp_min/tmp_max to calc and save scale/zero_point for further | |||
| # calculation in FakeQuant. | |||
| self.set_scale_zero_point(tmp_min, tmp_max) | |||
| def set_scale_zero_point(self, tmp_min, tmp_max): | |||
| if self.symmetric: | |||
| symmetric_max_vals = F.maximum(-tmp_min, tmp_max) | |||
| # use maximum to avoid the scale being too small at the beginning | |||
| self.scale = F.maximum( | |||
| symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit | |||
| ) | |||
| # zero_point = self.zero_point | |||
| else: | |||
| # use maximum to avoid the scale being too small at the beginning | |||
| self.scale = F.maximum( | |||
| (tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit | |||
| ) | |||
| # calculate zero_point | |||
| self.zero_point = self.qmin - Round()((tmp_min / self.scale)) | |||
| def get_qparams(self): | |||
| # scale and zero_point are runtime tensors rather than Buffers, | |||
| # so they need to be re-calculated if min_val and max_val are loaded. | |||
| if self.scale is None: | |||
| self.set_scale_zero_point(self.min_val, self.max_val) | |||
| return self.scale, self.zero_point | |||
| def forward(self, x_orig): | |||
| if self.enabled: | |||
| # stop gradient | |||
| x = F.zero_grad(x_orig) | |||
| # find max and min | |||
| tmp_min, _ = F.cond_take( | |||
| self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())]) | |||
| ) | |||
| tmp_max, _ = F.cond_take( | |||
| self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())]) | |||
| ) | |||
| self.set_min_max(tmp_min, tmp_max) | |||
| return x_orig | |||
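Plugging numbers into ``set_scale_zero_point``: for qint8 (qmin=-128, qmax=127) with an observed range of [-1.0, 3.0], a worked example:

.. code-block::

    qmin, qmax = -128, 127
    tmp_min, tmp_max = -1.0, 3.0
    eps = 0.00001

    # asymmetric mode
    scale = max((tmp_max - tmp_min) / (qmax - qmin), eps)  # 4 / 255 ~= 0.0157
    zero_point = qmin - round(tmp_min / scale)             # -128 - (-64) = -64

    # symmetric mode: zero_point is fixed at (qmin + qmax + 1) // 2 == 0
    sym_scale = max(max(-tmp_min, tmp_max) / ((qmax - qmin) / 2), eps)  # 3 / 127.5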
| class ExponentialMovingAverageObserver(MinMaxObserver): | |||
| def __init__(self, momentum=0.9, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self.momentum = Buffer(momentum) | |||
| def set_momentum(self, momentum): | |||
| self.momentum.set_value(momentum) | |||
| def forward(self, x_orig): | |||
| if self.enabled: | |||
| # stop gradient | |||
| x = F.zero_grad(x_orig) | |||
| # Exponential Moving Average | |||
| tmp_min, _ = F.cond_take( | |||
| self.first_flag, | |||
| F.concat( | |||
| [ | |||
| x.min(), | |||
| self.momentum * self.min_val + (1 - self.momentum) * x.min(), | |||
| ] | |||
| ), | |||
| ) | |||
| tmp_max, _ = F.cond_take( | |||
| self.first_flag, | |||
| F.concat( | |||
| [ | |||
| x.max(), | |||
| self.momentum * self.max_val + (1 - self.momentum) * x.max(), | |||
| ] | |||
| ), | |||
| ) | |||
| self.set_min_max(tmp_min, tmp_max) | |||
| return x_orig | |||
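The EMA update above smooths the tracked range across batches: with momentum 0.9, each new batch statistic moves min/max only 10% of the way. A worked step:

.. code-block::

    momentum = 0.9
    running_min, batch_min = -1.0, -2.0
    new_min = momentum * running_min + (1 - momentum) * batch_min  # -1.1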
| @@ -0,0 +1,82 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from functools import partial | |||
| from ..module import Module | |||
| from .fake_quant import FakeQuantize | |||
| from .observer import ExponentialMovingAverageObserver, MinMaxObserver | |||
| class QConfig: | |||
| """ | |||
| A config class indicating how to do quantization toward a :class:`~.QATModule`'s | |||
| ``activation``, ``weight`` and ``bias``, with ``fake_quant`` parameters indicating | |||
| how to do the fake-quant calculation. | |||
| See :meth:`~.QATModule.set_qconfig` for detailed usage. | |||
| :param inp_observer: interface to instantiate an :class:`~.Observer` indicating | |||
| how to collect scales and zero_point of input. | |||
| :param weight_observer: similar to ``inp_observer`` but toward weight. | |||
| :param act_observer: similar to ``inp_observer`` but toward activation. | |||
| :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||
| how to do the fake_quant calculation. Can be invoked multiple times to get a | |||
| different instance for each target tensor, for better control over enabling and disabling. | |||
| :param bias_fake_quant: similar to ``fake_quant``, but the ``dtype`` usually needs to be | |||
| set in advance, since the bias dtype cannot be inferred from an observer. | |||
| Examples: | |||
| .. code-block:: | |||
| # Default EMA QConfig for QAT, matching the definition at the end of this file. | |||
| ema_fakequant_qconfig = QConfig( | |||
| inp_observer=ExponentialMovingAverageObserver, | |||
| weight_observer=MinMaxObserver, | |||
| act_observer=ExponentialMovingAverageObserver, | |||
| fake_quant=FakeQuantize, | |||
| bias_fake_quant=partial(FakeQuantize, dtype="qint32"), | |||
| ) | |||
| """ | |||
| def __init__( | |||
| self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant, | |||
| ): | |||
| if ( | |||
| isinstance(act_observer, Module) | |||
| or isinstance(weight_observer, Module) | |||
| or isinstance(inp_observer, Module) | |||
| ): | |||
| raise ValueError( | |||
| "QConfig must not receive observer instance, please pass observer" | |||
| " class generator using `partial(Observer, ...)` instead. Use" | |||
| " partial(MyObserver, x=1) to override arguments to constructor if needed" | |||
| ) | |||
| self.act_observer = act_observer | |||
| self.weight_observer = weight_observer | |||
| self.inp_observer = inp_observer | |||
| self.fake_quant = fake_quant | |||
| self.bias_fake_quant = bias_fake_quant | |||
| # Default QAT QConfigs | |||
| min_max_fakequant_qconfig = QConfig( | |||
| inp_observer=MinMaxObserver, | |||
| weight_observer=MinMaxObserver, | |||
| act_observer=MinMaxObserver, | |||
| fake_quant=FakeQuantize, | |||
| bias_fake_quant=partial(FakeQuantize, dtype="qint32"), | |||
| ) | |||
| ema_fakequant_qconfig = QConfig( | |||
| inp_observer=ExponentialMovingAverageObserver, | |||
| weight_observer=MinMaxObserver, | |||
| act_observer=ExponentialMovingAverageObserver, | |||
| fake_quant=FakeQuantize, | |||
| bias_fake_quant=partial(FakeQuantize, dtype="qint32"), | |||
| ) | |||
| @@ -0,0 +1,113 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from copy import deepcopy | |||
| from ..module import Module, QATModule, Sequential, quantized | |||
| from .qconfig import QConfig, ema_fakequant_qconfig | |||
| def quantize(module: Module, inplace=True): | |||
| r""" | |||
| Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`. | |||
| :param module: root module to do convert recursively. | |||
| """ | |||
| if not inplace: | |||
| module = deepcopy(module) | |||
| def is_qat_module(obj): | |||
| return isinstance(obj, QATModule) | |||
| # no prefix is passed, so keys are pure keys relative to the parent Module. | |||
| for key, submodule, parent in module._flatten( | |||
| with_key=True, with_parent=True, predicate=is_qat_module | |||
| ): | |||
| if isinstance(parent, Sequential): | |||
| # cannot use setattr; use Sequential's ``__setitem__`` to stay compatible | |||
| parent[int(key.split(".")[-1])] = submodule.to_quantized() | |||
| else: | |||
| setattr(parent, key.split(".")[-1], submodule.to_quantized()) | |||
| def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||
| r""" | |||
| Recursively convert `module` to `qat` mode through :meth:`~.Module.apply` | |||
| and set qconfig accordingly. | |||
| :param module: root module to do convert recursively. | |||
| :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||
| default is :any:`~.qconfig.ema_fakequant_qconfig`. | |||
| """ | |||
| def fn(mod: Module): | |||
| if isinstance(mod, QATModule): | |||
| mod.set_qat_mode(QATModule.QATMode.QAT) | |||
| mod.set_qconfig(qconfig) | |||
| module.apply(fn) | |||
| def disable_fake_quant(module: Module): | |||
| r""" | |||
| Recursively disable fake quantization of every QATModule in `module` through :meth:`~.Module.apply`. | |||
| :param module: root module to recursively disable fake quantization for. | |||
| """ | |||
| def fn(mod): | |||
| if isinstance(mod, QATModule): | |||
| mod.act_fake_quant.disable() | |||
| mod.weight_fake_quant.disable() | |||
| mod.bias_fake_quant.disable() | |||
| module.apply(fn) | |||
| def disable_observer(module: Module): | |||
| r""" | |||
| Recursively disable the observers of every QATModule in `module` through :meth:`~.Module.apply`. | |||
| :param module: root module to recursively disable observers for. | |||
| """ | |||
| def fn(mod): | |||
| if isinstance(mod, QATModule): | |||
| mod.act_observer.disable() | |||
| module.apply(fn) | |||
| def enable_fake_quant(module: Module): | |||
| r""" | |||
| Recursively enable fake quantization of every QATModule in `module` through :meth:`~.Module.apply`. | |||
| :param module: root module to recursively enable fake quantization for. | |||
| """ | |||
| def fn(mod): | |||
| if isinstance(mod, QATModule): | |||
| mod.act_fake_quant.enable() | |||
| mod.weight_fake_quant.enable() | |||
| mod.bias_fake_quant.enable() | |||
| module.apply(fn) | |||
| def enable_observer(module: Module): | |||
| r""" | |||
| Recursively enable the observers of every QATModule in `module` through :meth:`~.Module.apply`. | |||
| :param module: root module to recursively enable observers for. | |||
| """ | |||
| def fn(mod): | |||
| if isinstance(mod, QATModule): | |||
| mod.act_observer.enable() | |||
| module.apply(fn) | |||
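Putting these helpers together, a typical QAT-to-deployment flow looks like the following sketch (``net`` and the training loop are placeholders):

.. code-block::

    from megengine.quantization import ema_fakequant_qconfig, quantize, quantize_qat

    quantize_qat(net, ema_fakequant_qconfig)  # switch QATModules into QAT mode
    # ... run quantization-aware training on net ...
    quantize(net)  # irreversibly swap each QATModule for its quantized module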
| @@ -0,0 +1,23 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from functools import partial, update_wrapper, wraps | |||
| def register_method_to_class(cls): | |||
| def decorator(func): | |||
| @wraps(func) | |||
| def wrapper(self, *args, **kwargs): | |||
| return func(self, *args, **kwargs) | |||
| if isinstance(func, partial): | |||
| update_wrapper(func, func.func) | |||
| setattr(cls, func.__name__, wrapper) | |||
| return func | |||
| return decorator | |||
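Usage mirrors the ``to_quantized`` registrations earlier in this diff. A toy sketch (``Foo`` and ``greet`` are hypothetical):

.. code-block::

    from megengine.quantization.utils import register_method_to_class

    class Foo:
        pass

    @register_method_to_class(Foo)
    def greet(self):
        return "hello from {}".format(type(self).__name__)

    print(Foo().greet())  # the free function is now an instance method of Foo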
| @@ -7,10 +7,12 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| import pytest | |||
| from helpers import opr_test | |||
| import megengine._internal as mgb | |||
| import megengine.functional as F | |||
| from megengine import Buffer, jit, tensor | |||
| from megengine import Buffer, Parameter, is_cuda_available, jit, tensor | |||
| from megengine.test import assertTensorClose | |||
| @@ -332,3 +334,108 @@ def test_binary_cross_entropy(): | |||
| {"input": [data2, label2], "output": expect2,}, | |||
| ] | |||
| opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn) | |||
| @pytest.mark.skip | |||
| def test_conv_bias(): | |||
| inp_scale = 0.01 | |||
| w_scale = 0.02 | |||
| outp_scale = 0.1 | |||
| inp_dtype = mgb.dtype.qint8(inp_scale) | |||
| w_dtype = mgb.dtype.qint8(w_scale) | |||
| b_dtype = mgb.dtype.qint32(inp_scale * w_scale) | |||
| out_dtype = mgb.dtype.qint8(outp_scale) | |||
| def run( | |||
| N, | |||
| IC, | |||
| OC, | |||
| IH, | |||
| IW, | |||
| KH, | |||
| KW, | |||
| PH, | |||
| PW, | |||
| SH, | |||
| SW, | |||
| has_bias=True, | |||
| nonlinear_mode="IDENTITY", | |||
| ): | |||
| inp_v = np.random.normal(size=(N, IC, IH, IW)) | |||
| w_v = np.random.normal(size=(OC, IC, KH, KW)) | |||
| b_v = np.random.normal(size=(1, OC, 1, 1)) | |||
| inp_scale = mgb.dtype.get_scale(inp_dtype) | |||
| w_scale = mgb.dtype.get_scale(w_dtype) | |||
| b_scale = mgb.dtype.get_scale(b_dtype) | |||
| inpv = mgb.dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype) | |||
| wv = mgb.dtype.convert_to_qint8(w_v * w_scale, w_dtype) | |||
| bv = mgb.dtype.convert_to_qint32(b_v * b_scale, b_dtype) | |||
| inp_int8 = tensor(inpv, dtype=inp_dtype) | |||
| w_int8 = Parameter(wv, dtype=w_dtype) | |||
| b_int32 = Parameter(bv, dtype=b_dtype) | |||
| inp_fp32 = inp_int8.astype("float32") | |||
| w_fp32 = w_int8.astype("float32") | |||
| b_fp32 = b_int32.astype("float32") | |||
| jit.trace.enabled = True | |||
| b_symbolic = True | |||
| def convert_to_nchw4(var): | |||
| return var.reshape( | |||
| var.shapeof(0), var.shapeof(1) // 4, 4, var.shapeof(2), var.shapeof(3) | |||
| ).dimshuffle(0, 1, 3, 4, 2) | |||
| @jit.trace(symbolic=b_symbolic) | |||
| def run_conv2d(inp, w, b): | |||
| O = F.conv2d( | |||
| inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW), | |||
| ) | |||
| if nonlinear_mode == "RELU": | |||
| return F.relu(O) | |||
| else: | |||
| return O | |||
| @jit.trace(symbolic=b_symbolic) | |||
| def run_conv_bias(inp, w, b, format="NCHW"): | |||
| b = b if has_bias else np.zeros_like(b) | |||
| if format == "NCHW4": | |||
| inp = convert_to_nchw4(inp) | |||
| w = convert_to_nchw4(w) | |||
| b = F.flatten(b) | |||
| return F.conv_bias_activation( | |||
| inp, | |||
| w, | |||
| b, | |||
| stride=(SH, SW), | |||
| padding=(PH, PW), | |||
| dtype=out_dtype, | |||
| nonlinear_mode=nonlinear_mode, | |||
| ) | |||
| format = "NCHW4" if is_cuda_available() else "NCHW" | |||
| expected = run_conv2d(inp_fp32, w_fp32, b_fp32) | |||
| expected = expected.astype(out_dtype).astype("float32") | |||
| result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype( | |||
| "float32" | |||
| ) | |||
| if format == "NCHW4": | |||
| result = result.dimshuffle(0, 1, 4, 2, 3) | |||
| expected = F.flatten(expected) | |||
| result = F.flatten(result) | |||
| assertTensorClose(result.numpy(), expected.numpy()) | |||
| if not is_cuda_available(): | |||
| run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False) | |||
| run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False) | |||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False) | |||
| run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1) | |||
| run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1) | |||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2) | |||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU") | |||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU") | |||