GitOrigin-RevId: f6dbd1f4c0
@@ -45,8 +45,6 @@ def conv_bias_activation(
     :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
         'CROSS_CORRELATION'
     :param dtype: support for ``np.dtype``, Default: np.int8
-    :param scale: scale if use quantization, Default: 0.0
-    :param zero_point: scale if use quantization quint8, Default: 0.0
     :type compute_mode: string or
         :class:`P.Convolution.ComputeMode`.
     :param compute_mode: when set to "DEFAULT", no special requirements will be
@@ -75,3 +73,63 @@ def conv_bias_activation(
     )
     (outputs,) = apply(op, inp, weight, bias)
     return outputs
+
+
+def batch_conv_bias_activation(
+    inp: Tensor,
+    weight: Tensor,
+    bias: Tensor,
+    dtype=None,
+    stride: Union[int, Tuple[int, int]] = 1,
+    padding: Union[int, Tuple[int, int]] = 0,
+    dilation: Union[int, Tuple[int, int]] = 1,
+    groups: int = 1,
+    nonlinear_mode="IDENTITY",
+    conv_mode="CROSS_CORRELATION",
+    compute_mode="DEFAULT",
+) -> Tensor:
+    """
+    Batch convolution bias with activation operation, only for inference.
+
+    :param inp: feature map of the convolution operation.
+    :param weight: convolution kernel in batched form.
+    :param bias: bias added to the result of convolution.
+    :param stride: stride of the 2D convolution operation. Default: 1
+    :param padding: size of the paddings added to the input on both sides of its
+        spatial dimensions. Only zero-padding is supported. Default: 0
+    :param dilation: dilation of the 2D convolution operation. Default: 1
+    :param groups: number of groups into which the input and output channels are
+        divided, so as to perform a "grouped convolution". When ``groups`` is not 1,
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
+        and the shape of weight should be `(groups, out_channel // groups,
+        in_channels // groups, height, width)`.
+    :type conv_mode: string or :class:`P.Convolution.Mode`.
+    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
+        'CROSS_CORRELATION'
+    :param dtype: support for ``np.dtype``, Default: np.int8
+    :type compute_mode: string or
+        :class:`P.Convolution.ComputeMode`.
+    :param compute_mode: when set to "DEFAULT", no special requirements will be
+        placed on the precision of intermediate results. When set to "FLOAT32",
+        "Float32" would be used for accumulator and intermediate result, but only
+        effective when input and output are of Float16 dtype.
+    """
+    ph, pw = _pair(padding)
+    sh, sw = _pair_nonzero(stride)
+    dh, dw = _pair_nonzero(dilation)
+    sparse_type = "DENSE" if groups == 1 else "GROUP"
+    op = builtin.BatchConvBiasForward(
+        stride_h=sh,
+        stride_w=sw,
+        pad_h=ph,
+        pad_w=pw,
+        dilate_h=dh,
+        dilate_w=dw,
+        dtype=dtype,
+        format="NCHW",
+        strategy=get_conv_execution_strategy(),
+        nonlineMode=nonlinear_mode,
+        mode=conv_mode,
+        compute_mode=compute_mode,
+        sparse=sparse_type,
+    )
+    (outputs,) = apply(op, inp, weight, bias)
+    return outputs
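Usage note: a minimal sketch of calling the new primitive with explicitly quantized tensors, distilled from `test_batch_conv_bias` below. The scales and shapes are illustrative assumptions, and `F` is assumed to be `megengine.functional` (the tests use `F.quantized.batch_conv_bias_activation`):

```python
import numpy as np

import megengine.functional as F
from megengine import Parameter, tensor
from megengine.core.tensor import dtype

# Illustrative quantization parameters, borrowed from the test below.
inp_dtype, w_dtype = dtype.qint8(1.5), dtype.qint8(2.5)
b_dtype, out_dtype = dtype.qint32(1.5 * 2.5), dtype.qint8(1.5)
N, IC, OC, IH, IW, KH, KW = 1, 4, 4, 5, 5, 3, 3

inp = tensor(
    dtype.convert_to_qint8(np.random.normal(size=(N, IC, IH, IW)), inp_dtype),
    dtype=inp_dtype,
)
# One kernel per batch sample: the weight is 5-dimensional, (N, OC, IC, KH, KW).
w = Parameter(
    dtype.convert_to_qint8(np.random.normal(size=(N, OC, IC, KH, KW)), w_dtype),
    dtype=w_dtype,
)
b = Parameter(
    dtype.convert_to_qint32(np.random.normal(size=(1, OC, 1, 1)), b_dtype),
    dtype=b_dtype,
)

out = F.quantized.batch_conv_bias_activation(
    inp, w, b, stride=(1, 1), padding=(0, 0), dtype=out_dtype
)  # qint8 feature map of shape (N, OC, IH - KH + 1, IW - KW + 1)
```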
@@ -9,6 +9,7 @@
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
 from .adaptive_pooling import AdaptiveAvgPool2d, AdaptiveMaxPool2d
+from .batch_matmul_activation import BatchMatMulActivation
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
 from .conv import Conv1d, Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
@@ -0,0 +1,67 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+from ..functional import matmul, relu
+from ..tensor import Parameter
+from . import init
+from .module import Module
+
+
+class BatchMatMulActivation(Module):
+    r"""
+    Batched MatMul with activation (only ReLU supported), no transpose anywhere.
+    """
+
+    def __init__(
+        self,
+        batch: int,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        nonlinear_mode="IDENTITY",
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.batch = batch
+        self.out_features = out_features
+        self.in_features = in_features
+        w_shape = (batch, out_features, in_features)
+        self.weight = Parameter(np.zeros(w_shape, dtype=np.float32))
+        self.bias = None
+        if bias:
+            b_shape = (out_features,)
+            self.bias = Parameter(np.zeros(b_shape, dtype=np.float32))
+        self.nonlinear_mode = nonlinear_mode
+        self.reset_parameters()
+
+    def _get_fanin(self):
+        return self.in_features
+
+    def reset_parameters(self) -> None:
+        fanin = self._get_fanin()
+        std = np.sqrt(1 / fanin)
+        init.normal_(self.weight, 0.0, std)
+        if self.bias is not None:
+            init.zeros_(self.bias)
+
+    def _calc_linear(self, x, weight, bias):
+        res = matmul(weight, x)
+        if self.bias is not None:
+            res += bias
+        if self.nonlinear_mode == "RELU":
+            res = relu(res)
+        return res
+
+    def forward(self, x):
+        return self._calc_linear(x, self.weight, self.bias)
+
+    def _module_info_string(self) -> str:
+        return "batch={}, in_features={}, out_features={}, bias={}".format(
+            self.batch, self.in_features, self.out_features, self.bias is not None
+        )
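Usage note: the weight layout is `(batch, out_features, in_features)` and `forward` computes `matmul(weight, x)`, so the input must be `(batch, in_features, n)`. A minimal float-mode sketch; the sizes are illustrative and chosen, as in the tests below, so that the `(out_features,)` bias broadcasts over the last axis:

```python
import numpy as np

from megengine import tensor
from megengine.module import BatchMatMulActivation

# weight: (batch, out_features, in_features); forward computes matmul(weight, x).
# With bias=True the (out_features,) bias broadcasts over the trailing axis,
# so n is set equal to out_features here, matching the tests below.
m = BatchMatMulActivation(batch=4, in_features=8, out_features=4, nonlinear_mode="RELU")
x = tensor(np.random.randn(4, 8, 4).astype(np.float32))
y = m(x)  # shape (4, 4, 4), with elementwise ReLU applied
```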
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .batch_matmul_activation import BatchMatMulActivation
 from .concat import Concat
 from .conv import Conv2d, ConvRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
@@ -0,0 +1,30 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...quantization.utils import fake_quant_bias
+from .. import batch_matmul_activation as Float
+from .module import QATModule
+
+
+class BatchMatMulActivation(Float.BatchMatMulActivation, QATModule):
+    def forward(self, inp):
+        w_qat = self.apply_quant_weight(self.weight)
+        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        return self.apply_quant_activation(self._calc_linear(inp, w_qat, b_qat))
+
+    @classmethod
+    def from_float_module(cls, float_module: Float.BatchMatMulActivation):
+        qat_module = cls(
+            float_module.batch,
+            float_module.in_features,
+            float_module.out_features,
+            float_module.bias is not None,
+            float_module.nonlinear_mode,
+        )
+        qat_module.weight = float_module.weight
+        qat_module.bias = float_module.bias
+        return qat_module
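For orientation: in practice this QAT module is produced by `quantize_qat` (used in the tests below), which performs the substitution and also attaches the qconfig that `apply_quant_weight`/`apply_quant_activation` rely on. A sketch of the bare conversion; the import paths follow the `__init__.py` exports added in this patch:

```python
from megengine.module import BatchMatMulActivation as FloatBMM
from megengine.module.qat import BatchMatMulActivation as QATBMM

float_m = FloatBMM(batch=4, in_features=8, out_features=4)
# from_float_module copies the hyperparameters and shares weight/bias;
# quantization settings still come from a qconfig, normally via quantize_qat.
qat_m = QATBMM.from_float_module(float_m)
```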
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .batch_matmul_activation import BatchMatMulActivation
 from .concat import Concat
 from .conv import Conv2d, ConvRelu2d
 from .conv_bn import ConvBn2d, ConvBnRelu2d
@@ -0,0 +1,76 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Tuple, Union
+
+import numpy as np
+
+from ... import module as Float
+from ...core.tensor import dtype
+from ...functional import expand_dims, squeeze
+from ...functional.quantized import batch_conv_bias_activation
+from ...tensor import Parameter
+from ..qat import batch_matmul_activation as QAT
+from .module import QuantizedModule
+
+
+class BatchMatMulActivation(Float.BatchMatMulActivation, QuantizedModule):
+    def __init__(
+        self,
+        batch: int,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        nonlinear_mode="IDENTITY",
+        dtype=None,
+        **kwargs
+    ):
+        # pass nonlinear_mode through to the float constructor
+        super().__init__(
+            batch, in_features, out_features, bias, nonlinear_mode, **kwargs
+        )
+        self.output_dtype = dtype
+
+    def calc_bmm_quantized(self, inp):
+        inp_scale = dtype.get_scale(inp.dtype)
+        w_scale = dtype.get_scale(self.weight.dtype)
+        bias_scale = inp_scale * w_scale
+        inp = expand_dims(inp, [-1])
+        res = batch_conv_bias_activation(
+            inp,
+            self.weight,
+            self.bias.astype(dtype.qint32(bias_scale)),
+            dtype=self.output_dtype,
+            stride=1,
+            padding=0,
+            dilation=1,
+            groups=1,
+            nonlinear_mode=self.nonlinear_mode,
+        )
+        return squeeze(res, -1)
+
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.BatchMatMulActivation):
+        output_dtype = qat_module.get_activation_dtype()
+        qbmm = cls(
+            qat_module.batch,
+            qat_module.in_features,
+            qat_module.out_features,
+            qat_module.bias is not None,
+            nonlinear_mode=qat_module.nonlinear_mode,
+            dtype=output_dtype,
+        )
+        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
+        weight = expand_dims(weight, [-1, -2])
+        qbmm.weight = Parameter(weight.numpy())
+        if qat_module.bias is not None:
+            bias = qat_module.bias.reshape((1, qbmm.out_features, 1, 1))
+            qbmm.bias = Parameter(bias.numpy())
+        else:
+            qbmm.bias = Parameter(
+                np.zeros((1, qbmm.out_features, 1, 1), dtype=np.float32)
+            )
+        return qbmm
+
+    def forward(self, inp):
+        return self.calc_bmm_quantized(inp)
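The quantized forward lowers the batched matmul onto the new `batch_conv_bias_activation` primitive by viewing each `(out_features, in_features)` weight slice as `out_features` 1x1 kernels. A NumPy sketch of the shape bookkeeping (illustrative sizes, float arithmetic only):

```python
import numpy as np

B, M, K, N = 4, 4, 8, 16  # batch, out_features, in_features, free dimension
w = np.random.randn(B, M, K)
x = np.random.randn(B, K, N)
ref = w @ x  # the batched matmul the module computes, shape (B, M, N)

# from_qat_module: expand weight with two trailing axes -> per-sample 1x1 kernels
w_conv = w[:, :, :, None, None]  # (B, M, K, 1, 1)
# calc_bmm_quantized: expand input -> NCHW feature map with C=K, H=N, W=1
x_conv = x[:, :, :, None]        # (B, K, N, 1)
# a batch conv with 1x1 kernels is exactly the batched matmul
out = np.einsum("bok,bkhw->bohw", w_conv[:, :, :, 0, 0], x_conv)  # (B, M, N, 1)
assert np.allclose(out[..., 0], ref)  # squeeze(res, -1) restores (B, M, N)
```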
@@ -20,6 +20,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
 from megengine.core.tensor.utils import make_shape_tuple
+from megengine.distributed.helper import get_device_count_by_fork
 
 
 def test_where():
@@ -420,7 +421,9 @@ def test_nms():
     np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
 
 
-@pytest.mark.skip(reason="cuda does not support nchw int8")
+@pytest.mark.skipif(
+    get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
+)
 def test_conv_bias():
     inp_scale = 1.5
     w_scale = 2.5
@@ -446,7 +449,7 @@ def test_conv_bias():
         nonlinear_mode="IDENTITY",
     ):
         inp_v = np.random.normal(size=(N, IC, IH, IW))
-        w_v = np.random.normal(size=(OC, IC, KW, KW))
+        w_v = np.random.normal(size=(OC, IC, KH, KW))
         b_v = np.random.normal(size=(1, OC, 1, 1))
         inp_scale = dtype.get_scale(inp_dtype)
         w_scale = dtype.get_scale(w_dtype)
@@ -486,13 +489,12 @@ def test_conv_bias():
             inp = convert_to_nchw4(inp)
             w = convert_to_nchw4(w)
             b = convert_to_nchw4(b)
-        return F.nn.conv_bias_activation(
+        return F.quantized.conv_bias_activation(
             inp,
             w,
             b,
             stride=(SH, SW),
             padding=(PH, PW),
-            format=format,
             dtype=out_dtype,
             nonlinear_mode=nonlinear_mode,
         )
@@ -522,6 +524,59 @@ def test_conv_bias():
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")
+
+
+@pytest.mark.skipif(
+    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
+)
+def test_batch_conv_bias():
+    inp_scale = 1.5
+    w_scale = 2.5
+    outp_scale = 1.5
+    inp_dtype = dtype.qint8(inp_scale)
+    w_dtype = dtype.qint8(w_scale)
+    b_dtype = dtype.qint32(inp_scale * w_scale)
+    out_dtype = dtype.qint8(outp_scale)
+
+    def run(
+        N, IC, OC, IH, IW, KH, KW, PH, PW, SH, SW, has_bias=True,
+    ):
+        inp_v = np.random.normal(size=(N, IC, IH, IW))
+        w_v = np.random.normal(size=(N, OC, IC, KH, KW))
+        b_v = np.random.normal(size=(1, OC, 1, 1))
+        inp_scale = dtype.get_scale(inp_dtype)
+        w_scale = dtype.get_scale(w_dtype)
+        b_scale = dtype.get_scale(b_dtype)
+
+        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
+        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
+        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
+
+        inp_int8 = tensor(inpv, dtype=inp_dtype)
+        w_int8 = Parameter(wv, dtype=w_dtype)
+        b_int32 = Parameter(bv, dtype=b_dtype)
+
+        inp_fp32 = inp_int8.astype("float32")
+        w_fp32 = w_int8.astype("float32")
+        b_fp32 = b_int32.astype("float32")
+
+        def run_batch_conv_bias(inp, w, b):
+            b = b if has_bias else Parameter(np.zeros_like(b.numpy()))
+            result = F.quantized.batch_conv_bias_activation(
+                inp, w, b, stride=(SH, SW), padding=(PH, PW), dtype=out_dtype,
+            )
+            return result.astype("float32")
+
+        expected = F.conv2d(inp_fp32, w_fp32[0], b_fp32 if has_bias else None)[0]
+        expected = expected.astype(out_dtype).astype("float32")
+        expected = F.flatten(expected)
+
+        result = run_batch_conv_bias(inp_int8, w_int8, b_int32)
+        result = F.flatten(result)
+
+        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
+
+    run(1, 4, 4, 5, 5, 3, 3, 0, 0, 1, 1, True)
 
 
 def test_zero_stride_numpy_array():
     inp = np.random.randn(3, 224, 224).astype(np.float32)
     inp = inp[np.newaxis, :]
@@ -1,9 +1,15 @@
+import io
 from itertools import product
 
 import numpy as np
+import pytest
 
-from megengine import tensor
+import megengine.utils.comp_graph_tools as cgtools
+from megengine import jit, tensor
+from megengine.distributed.helper import get_device_count_by_fork
+from megengine.functional import expand_dims
 from megengine.module import (
+    BatchMatMulActivation,
     Conv2d,
     ConvBn2d,
     ConvRelu2d,
@@ -11,7 +17,12 @@ from megengine.module import (
     Module,
     QuantStub,
 )
-from megengine.quantization.quantize import disable_fake_quant, quantize_qat
+from megengine.quantization.quantize import (
+    disable_fake_quant,
+    enable_fake_quant,
+    quantize,
+    quantize_qat,
+)
 
 
 def test_qat_convbn2d():
@@ -88,3 +99,107 @@ def test_qat_conv():
         qat_net.eval()
         qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+
+@pytest.mark.skipif(
+    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
+)
+def test_qat_batchmatmul_activation():
+    batch = 4
+    in_features = 8
+    out_features = 4
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.batch_mm = BatchMatMulActivation(
+                batch, in_features, out_features, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.batch_mm(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(
+        np.random.randn(batch, in_features, out_features).astype(np.float32)
+    )
+    for bias in (True, False):
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+
+@pytest.mark.skip(reason="FIXME: abnormal exit")
+def test_quantize_batchmatmul_activation():
+    batch = 4
+    in_features = 8
+    out_features = 4
+
+    class TestNet(Module):
+        def __init__(self, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.batch_mm = BatchMatMulActivation(
+                batch, in_features, out_features, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.batch_mm(out)
+            out = expand_dims(out, -1)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(
+        np.random.randn(batch, in_features, out_features).astype(np.float32)
+    )
+    for bias in (True, False):
+        net = TestNet(bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
+
+        enable_fake_quant(qat_net)
+        qat_outputs = qat_net(inputs)
+        qnet = quantize(qat_net, inplace=False)
+        qnet.eval()
+        quantize_outputs = qnet(inputs)
+        np.testing.assert_allclose(
+            qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
+        )
+
+        @jit.trace(capture_as_const=True)
+        def f(x):
+            qnet.eval()
+            return qnet(x)
+
+        f(inputs)
+        file = io.BytesIO()
+        f.dump(file, enable_nchw4=True)
+        file.seek(0)
+        dumped_outputs = cgtools.load_and_inference(file, [inputs])[0]
+        np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)