GitOrigin-RevId: f6dbd1f4c0
tags/v1.1.0
| @@ -45,8 +45,6 @@ def conv_bias_activation( | |||
| :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
| 'CROSS_CORRELATION' | |||
| :param dtype: support for ``np.dtype``, Default: np.int8 | |||
 | :param scale: scale used if quantization is applied. Default: 0.0 | |||
 | :param zero_point: zero point used if quint8 quantization is applied. Default: 0.0 | |||
| :type compute_mode: string or | |||
| :class:`P.Convolution.ComputeMode`. | |||
| :param compute_mode: when set to "DEFAULT", no special requirements will be | |||
| @@ -75,3 +73,63 @@ def conv_bias_activation( | |||
| ) | |||
| (outputs,) = apply(op, inp, weight, bias) | |||
| return outputs | |||
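For context on this hunk, a hedged usage sketch of the fused int8 conv-bias-activation as it is exercised by the updated test further down (quantization helpers come from megengine.core.tensor.dtype; scales here are arbitrary):

    import numpy as np
    import megengine.functional as F
    from megengine import Parameter, tensor
    from megengine.core.tensor import dtype

    inp_dtype, w_dtype = dtype.qint8(1.5), dtype.qint8(2.5)
    b_dtype, out_dtype = dtype.qint32(1.5 * 2.5), dtype.qint8(1.5)

    # NCHW feature map, a single (OC, IC, KH, KW) kernel, and a (1, OC, 1, 1) bias
    inp = tensor(dtype.convert_to_qint8(np.random.normal(size=(1, 4, 12, 12)) * 1.5, inp_dtype), dtype=inp_dtype)
    w = Parameter(dtype.convert_to_qint8(np.random.normal(size=(4, 4, 3, 3)) * 2.5, w_dtype), dtype=w_dtype)
    b = Parameter(dtype.convert_to_qint32(np.random.normal(size=(1, 4, 1, 1)) * 1.5 * 2.5, b_dtype), dtype=b_dtype)

    out = F.quantized.conv_bias_activation(
        inp, w, b, stride=(1, 1), padding=(0, 0), dtype=out_dtype, nonlinear_mode="RELU"
    )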
| def batch_conv_bias_activation( | |||
| inp: Tensor, | |||
| weight: Tensor, | |||
| bias: Tensor, | |||
| dtype=None, | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| nonlinear_mode="IDENTITY", | |||
| conv_mode="CROSS_CORRELATION", | |||
| compute_mode="DEFAULT", | |||
| ) -> Tensor: | |||
| """ | |||
 | Batched convolution-bias operation with activation, for inference only. | |||
 | :param inp: feature map of the convolution operation. | |||
 | :param weight: convolution kernel in batched form (one kernel per input sample). | |||
 | :param bias: bias added to the result of convolution. | |||
| :param stride: stride of the 2D convolution operation. Default: 1 | |||
| :param padding: size of the paddings added to the input on both sides of its spatial dimensions. Only zero-padding is supported. Default: 0 | |||
| :param dilation: dilation of the 2D convolution operation. Default: 1 | |||
| :param groups: number of groups into which the input and output channels are divided, so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
| ``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
 | and the shape of weight should be `(groups, out_channels // groups, | |||
| in_channels // groups, height, width)`. | |||
| :type conv_mode: string or :class:`P.Convolution.Mode`. | |||
| :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
| 'CROSS_CORRELATION' | |||
 | :param dtype: support for ``np.dtype``. Default: np.int8 | |||
| :type compute_mode: string or | |||
| :class:`P.Convolution.ComputeMode`. | |||
 | :param compute_mode: when set to "DEFAULT", no special requirements are | |||
 | placed on the precision of intermediate results. When set to "FLOAT32", | |||
 | float32 is used for the accumulator and intermediate results, but this only takes effect when the input and output are of float16 dtype. | |||
| """ | |||
| ph, pw = _pair(padding) | |||
| sh, sw = _pair_nonzero(stride) | |||
| dh, dw = _pair_nonzero(dilation) | |||
| sparse_type = "DENSE" if groups == 1 else "GROUP" | |||
| op = builtin.BatchConvBiasForward( | |||
| stride_h=sh, | |||
| stride_w=sw, | |||
| pad_h=ph, | |||
| pad_w=pw, | |||
| dilate_h=dh, | |||
| dilate_w=dw, | |||
| dtype=dtype, | |||
| format="NCHW", | |||
| strategy=get_conv_execution_strategy(), | |||
| nonlineMode=nonlinear_mode, | |||
| mode=conv_mode, | |||
| compute_mode=compute_mode, | |||
| sparse=sparse_type, | |||
| ) | |||
| (outputs,) = apply(op, inp, weight, bias) | |||
| return outputs | |||
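A hedged call sketch for the new operator, mirroring test_batch_conv_bias below; the only difference from conv_bias_activation is the extra leading batch dimension on the kernel, i.e. weight has shape (N, OC, IC, KH, KW):

    import numpy as np
    import megengine.functional as F
    from megengine import Parameter, tensor
    from megengine.core.tensor import dtype

    N, IC, OC, IH, IW, KH, KW = 1, 4, 4, 5, 5, 3, 3
    inp_dtype, w_dtype = dtype.qint8(1.5), dtype.qint8(2.5)
    b_dtype, out_dtype = dtype.qint32(1.5 * 2.5), dtype.qint8(1.5)

    inp = tensor(dtype.convert_to_qint8(np.random.normal(size=(N, IC, IH, IW)) * 1.5, inp_dtype), dtype=inp_dtype)
    w = Parameter(dtype.convert_to_qint8(np.random.normal(size=(N, OC, IC, KH, KW)) * 2.5, w_dtype), dtype=w_dtype)
    b = Parameter(dtype.convert_to_qint32(np.random.normal(size=(1, OC, 1, 1)) * 1.5 * 2.5, b_dtype), dtype=b_dtype)

    out = F.quantized.batch_conv_bias_activation(
        inp, w, b, stride=(1, 1), padding=(0, 0), dtype=out_dtype
    )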
| @@ -9,6 +9,7 @@ | |||
| from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax | |||
| from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d | |||
| from .batch_matmul_activation import BatchMatMulActivation | |||
| from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm | |||
| from .concat import Concat | |||
| from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d | |||
| @@ -0,0 +1,67 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import numpy as np | |||
| from ..functional import matmul, relu | |||
| from ..tensor import Parameter | |||
| from . import init | |||
| from .module import Module | |||
| class BatchMatMulActivation(Module): | |||
| r""" | |||
 | Batched MatMul with activation (only ReLU supported), no transpose anywhere. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| batch: int, | |||
| in_features: int, | |||
| out_features: int, | |||
| bias: bool = True, | |||
| nonlinear_mode="IDENTITY", | |||
| **kwargs | |||
| ): | |||
| super().__init__(**kwargs) | |||
| self.batch = batch | |||
| self.out_features = out_features | |||
| self.in_features = in_features | |||
| w_shape = (batch, out_features, in_features) | |||
| self.weight = Parameter(np.zeros(w_shape, dtype=np.float32)) | |||
| self.bias = None | |||
| if bias: | |||
| b_shape = (out_features,) | |||
| self.bias = Parameter(np.zeros(b_shape, dtype=np.float32)) | |||
| self.nonlinear_mode = nonlinear_mode | |||
| self.reset_parameters() | |||
| def _get_fanin(self): | |||
| return self.in_features | |||
| def reset_parameters(self) -> None: | |||
| fanin = self._get_fanin() | |||
| std = np.sqrt(1 / fanin) | |||
| init.normal_(self.weight, 0.0, std) | |||
| if self.bias is not None: | |||
| init.zeros_(self.bias) | |||
| def _calc_linear(self, x, weight, bias): | |||
| res = matmul(weight, x) | |||
| if self.bias is not None: | |||
| res += bias | |||
| if self.nonlinear_mode == "RELU": | |||
| res = relu(res) | |||
| return res | |||
| def forward(self, x): | |||
| return self._calc_linear(x, self.weight, self.bias) | |||
| def _module_info_string(self) -> str: | |||
| return "batch={}, in_features={}, out_features={}, bias={}".format( | |||
| self.batch, self.in_features, self.out_features, self.bias is not None | |||
| ) | |||
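A quick usage sketch of the float module. The layout follows the matmul(weight, x) call above: weight is (batch, out_features, in_features), and the tests below feed an input of shape (batch, in_features, out_features):

    import numpy as np
    from megengine import tensor
    from megengine.module import BatchMatMulActivation

    m = BatchMatMulActivation(batch=4, in_features=8, out_features=4, nonlinear_mode="RELU")
    x = tensor(np.random.randn(4, 8, 4).astype(np.float32))
    y = m(x)  # shape (4, 4, 4): one independent matmul per batch entry, then ReLU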
| @@ -5,6 +5,7 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .batch_matmul_activation import BatchMatMulActivation | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvRelu2d | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| @@ -0,0 +1,30 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ...quantization.utils import fake_quant_bias | |||
| from .. import batch_matmul_activation as Float | |||
| from .module import QATModule | |||
| class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule): | |||
| def forward(self, inp): | |||
| w_qat = self.apply_quant_weight(self.weight) | |||
| b_qat = fake_quant_bias(self.bias, inp, w_qat) | |||
| return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat)) | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float.BatchMatMulActivation): | |||
| qat_module = cls( | |||
| float_module.batch, | |||
| float_module.in_features, | |||
| float_module.out_features, | |||
| float_module.bias is not None, | |||
| ) | |||
| qat_module.weight = float_module.weight | |||
| qat_module.bias = float_module.bias | |||
| return qat_module | |||
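A hedged sketch of how this QAT variant is reached in practice, mirroring the tests added below: QuantStub/DequantStub delimit the quantized region and quantize_qat swaps float modules for their QAT counterparts.

    import numpy as np
    from megengine import tensor
    from megengine.module import BatchMatMulActivation, DequantStub, Module, QuantStub
    from megengine.quantization.quantize import quantize_qat

    class Net(Module):
        def __init__(self):
            super().__init__()
            self.quant = QuantStub()
            self.batch_mm = BatchMatMulActivation(4, 8, 4)
            self.dequant = DequantStub()

        def forward(self, x):
            return self.dequant(self.batch_mm(self.quant(x)))

    qat_net = quantize_qat(Net(), inplace=False)
    out = qat_net(tensor(np.random.randn(4, 8, 4).astype(np.float32)))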
| @@ -5,6 +5,7 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .batch_matmul_activation import BatchMatMulActivation | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvRelu2d | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| @@ -0,0 +1,76 @@ | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from typing import Tuple, Union | |||
| import numpy as np | |||
| from ... import module as Float | |||
| from ...core.tensor import dtype | |||
| from ...functional import expand_dims, squeeze | |||
| from ...functional.quantized import batch_conv_bias_activation | |||
| from ...tensor import Parameter | |||
| from ..qat import batch_matmul_activation as QAT | |||
| from .module import QuantizedModule | |||
| class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule): | |||
| def __init__( | |||
| self, | |||
| batch: int, | |||
| in_features: int, | |||
| out_features: int, | |||
| bias: bool = True, | |||
| nonlinear_mode="IDENTITY", | |||
| dtype=None, | |||
| **kwargs | |||
| ): | |||
| super().__init__(batch, in_features, out_features, bias, **kwargs) | |||
| self.output_dtype = dtype | |||
| def calc_bmm_quantized(self, inp): | |||
| inp_scale = dtype.get_scale(inp.dtype) | |||
| w_scale = dtype.get_scale(self.weight.dtype) | |||
| bias_scale = inp_scale * w_scale | |||
| inp = expand_dims(inp, [-1]) | |||
| res = batch_conv_bias_activation( | |||
| inp, | |||
| self.weight, | |||
| self.bias.astype(dtype.qint32(bias_scale)), | |||
| dtype=self.output_dtype, | |||
| stride=1, | |||
| padding=0, | |||
| dilation=1, | |||
| groups=1, | |||
| nonlinear_mode=self.nonlinear_mode, | |||
| ) | |||
| return squeeze(res, -1) | |||
| @classmethod | |||
| def from_qat_module(cls, qat_module: QAT.BatchMatMulActivation): | |||
| output_dtype = qat_module.get_activation_dtype() | |||
| qbmm = cls( | |||
| qat_module.batch, | |||
| qat_module.in_features, | |||
| qat_module.out_features, | |||
| qat_module.bias is not None, | |||
| dtype=output_dtype, | |||
| ) | |||
| weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | |||
| weight = expand_dims(weight, [-1, -2]) | |||
| qbmm.weight = Parameter(weight.numpy()) | |||
| if qat_module.bias is not None: | |||
| bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1)) | |||
| qbmm.bias = Parameter(bias.numpy()) | |||
| else: | |||
| qbmm.bias = Parameter( | |||
| np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32) | |||
| ) | |||
| return qbmm | |||
| def forward(self, inp): | |||
| return self.calc_bmm_quantized(inp) | |||
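Continuing the QAT sketch above: quantize() converts the QAT module into this quantized one through from_qat_module(). The expand_dims/squeeze pair exists because the batched matmul is executed as a batched 1x1 convolution via batch_conv_bias_activation. A hedged sketch (exact scales depend on the observers collected during QAT):

    import numpy as np
    from megengine import tensor
    from megengine.quantization.quantize import quantize

    qnet = quantize(qat_net, inplace=False)  # qat_net from the QAT sketch above
    qnet.eval()
    y = qnet(tensor(np.random.randn(4, 8, 4).astype(np.float32)))
    # Inside calc_bmm_quantized: the (4, 8, 4) input is expanded to (4, 8, 4, 1) (NCHW),
    # the (4, 4, 8) weight was stored as (4, 4, 8, 1, 1) at conversion time, so the
    # per-sample 1x1 convolution reproduces the matmul; squeeze(-1) restores (4, 4, 4).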
| @@ -20,6 +20,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor | |||
| from megengine.core._trace_option import use_symbolic_shape | |||
| from megengine.core.autodiff.grad import Grad | |||
| from megengine.core.tensor.utils import make_shape_tuple | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| def test_where(): | |||
| @@ -420,7 +421,9 @@ def test_nms(): | |||
| np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32)) | |||
| @pytest.mark.skip(reason="cuda does not support nchw int8") | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8" | |||
| ) | |||
| def test_conv_bias(): | |||
| inp_scale = 1.5 | |||
| w_scale = 2.5 | |||
| @@ -446,7 +449,7 @@ def test_conv_bias(): | |||
| nonlinear_mode="IDENTITY", | |||
| ): | |||
| inp_v = np.random.normal(size=(N, IC, IH, IW)) | |||
| w_v = np.random.normal(size=(OC, IC, KW, KW)) | |||
| w_v = np.random.normal(size=(OC, IC, KH, KW)) | |||
| b_v = np.random.normal(size=(1, OC, 1, 1)) | |||
| inp_scale = dtype.get_scale(inp_dtype) | |||
| w_scale = dtype.get_scale(w_dtype) | |||
| @@ -486,13 +489,12 @@ def test_conv_bias(): | |||
| inp = convert_to_nchw4(inp) | |||
| w = convert_to_nchw4(w) | |||
| b = convert_to_nchw4(b) | |||
| return F.nn.conv_bias_activation( | |||
| return F.quantized.conv_bias_activation( | |||
| inp, | |||
| w, | |||
| b, | |||
| stride=(SH, SW), | |||
| padding=(PH, PW), | |||
| format=format, | |||
| dtype=out_dtype, | |||
| nonlinear_mode=nonlinear_mode, | |||
| ) | |||
| @@ -522,6 +524,59 @@ def test_conv_bias(): | |||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU") | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda" | |||
| ) | |||
| def test_batch_conv_bias(): | |||
| inp_scale = 1.5 | |||
| w_scale = 2.5 | |||
| outp_scale = 1.5 | |||
| inp_dtype = dtype.qint8(inp_scale) | |||
| w_dtype = dtype.qint8(w_scale) | |||
| b_dtype = dtype.qint32(inp_scale * w_scale) | |||
| out_dtype = dtype.qint8(outp_scale) | |||
| def run( | |||
| N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True, | |||
| ): | |||
| inp_v = np.random.normal(size=(N, IC, IH, IW)) | |||
| w_v = np.random.normal(size=(N, OC, IC, KH, KW)) | |||
| b_v = np.random.normal(size=(1, OC, 1, 1)) | |||
| inp_scale = dtype.get_scale(inp_dtype) | |||
| w_scale = dtype.get_scale(w_dtype) | |||
| b_scale = dtype.get_scale(b_dtype) | |||
| inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype) | |||
| wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype) | |||
| bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype) | |||
| inp_int8 = tensor(inpv, dtype=inp_dtype) | |||
| w_int8 = Parameter(wv, dtype=w_dtype) | |||
| b_int32 = Parameter(bv, dtype=b_dtype) | |||
| inp_fp32 = inp_int8.astype("float32") | |||
| w_fp32 = w_int8.astype("float32") | |||
| b_fp32 = b_int32.astype("float32") | |||
| def run_batch_conv_bias(inp, w, b): | |||
| b = b if has_bias else Parameter(np.zeros_like(b.numpy())) | |||
| result = F.quantized.batch_conv_bias_activation( | |||
| inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype, | |||
| ) | |||
| return result.astype("float32") | |||
| expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0] | |||
| expected = expected.astype(out_dtype).astype("float32") | |||
| expected = F.flatten(expected) | |||
| result = run_batch_conv_bias(inp_int8, w_int8, b_int32) | |||
| result = F.flatten(result) | |||
| np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale) | |||
| run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True) | |||
| def test_zero_stride_numpy_array(): | |||
| inp = np.random.randn(3, 224, 224).astype(np.float32) | |||
| inp = inp[np.newaxis, :] | |||
| @@ -1,9 +1,15 @@ | |||
| import io | |||
| from itertools import product | |||
| import numpy as np | |||
| import pytest | |||
| from megengine import tensor | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine import jit, tensor | |||
| from megengine.distributed.helper import get_device_count_by_fork | |||
| from megengine.functional import expand_dims | |||
| from megengine.module import ( | |||
| BatchMatMulActivation, | |||
| Conv2d, | |||
| ConvBn2d, | |||
| ConvRelu2d, | |||
| @@ -11,7 +17,12 @@ from megengine.module import ( | |||
| Module, | |||
| QuantStub, | |||
| ) | |||
| from megengine.quantization.quantize import disable_fake_quant, quantize_qat | |||
| from megengine.quantization.quantize import ( | |||
| disable_fake_quant, | |||
| enable_fake_quant, | |||
| quantize, | |||
| quantize_qat, | |||
| ) | |||
| def test_qat_convbn2d(): | |||
| @@ -88,3 +99,107 @@ def test_qat_conv(): | |||
| qat_net.eval() | |||
| qat_outputs = qat_net(inputs) | |||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
| @pytest.mark.skipif( | |||
| get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda" | |||
| ) | |||
| def test_qat_batchmatmul_activation(): | |||
| batch = 4 | |||
| in_features = 8 | |||
| out_features = 4 | |||
| class TestNet(Module): | |||
| def __init__(self, bias): | |||
| super().__init__() | |||
| self.quant = QuantStub() | |||
| self.dequant = DequantStub() | |||
| self.batch_mm = BatchMatMulActivation( | |||
| batch, in_features, out_features, bias=bias | |||
| ) | |||
| def forward(self, inp): | |||
| out = self.quant(inp) | |||
| out = self.batch_mm(out) | |||
| out = self.dequant(out) | |||
| return out | |||
| inputs = tensor( | |||
| np.random.randn(batch, in_features, out_features).astype(np.float32) | |||
| ) | |||
| for bias in (True, False): | |||
| net = TestNet(bias) | |||
| net.train() | |||
| qat_net = quantize_qat(net, inplace=False) | |||
| disable_fake_quant(qat_net) | |||
| normal_outputs = net(inputs) | |||
| qat_outputs = qat_net(inputs) | |||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
| net.eval() | |||
| normal_outputs = net(inputs) | |||
| qat_net.eval() | |||
| qat_outputs = qat_net(inputs) | |||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
| @pytest.mark.skip(reason="FIXME: abnormal exit") | |||
| def test_quantize_batchmatmul_activation(): | |||
| batch = 4 | |||
| in_features = 8 | |||
| out_features = 4 | |||
| class TestNet(Module): | |||
| def __init__(self, bias): | |||
| super().__init__() | |||
| self.quant = QuantStub() | |||
| self.dequant = DequantStub() | |||
| self.batch_mm = BatchMatMulActivation( | |||
| batch, in_features, out_features, bias=bias | |||
| ) | |||
| def forward(self, inp): | |||
| out = self.quant(inp) | |||
| out = self.batch_mm(out) | |||
| out = expand_dims(out, -1) | |||
| out = self.dequant(out) | |||
| return out | |||
| inputs = tensor( | |||
| np.random.randn(batch, in_features, out_features).astype(np.float32) | |||
| ) | |||
| for bias in (True, False): | |||
| net = TestNet(bias) | |||
| net.train() | |||
| qat_net = quantize_qat(net, inplace=False) | |||
| disable_fake_quant(qat_net) | |||
| normal_outputs = net(inputs) | |||
| qat_outputs = qat_net(inputs) | |||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
| net.eval() | |||
| normal_outputs = net(inputs) | |||
| qat_net.eval() | |||
| qat_outputs = qat_net(inputs) | |||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
| enable_fake_quant(qat_net) | |||
| qat_outputs = qat_net(inputs) | |||
| qnet = quantize(qat_net, inplace=False) | |||
| qnet.eval() | |||
| quantize_outputs = qnet(inputs) | |||
| np.testing.assert_allclose( | |||
| qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6 | |||
| ) | |||
| @jit.trace(capture_as_const=True) | |||
| def f(x): | |||
| qnet.eval() | |||
| return qnet(x) | |||
| f(inputs) | |||
| file = io.BytesIO() | |||
| f.dump(file, enable_nchw4=True) | |||
| file.seek(0) | |||
| dumped_outputs = cgtools.load_and_inference(file, [inputs])[0] | |||
| np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6) | |||