| @@ -12,11 +12,13 @@ from .conv import ( | |||
| ConvRelu2d, | |||
| ConvTranspose2d, | |||
| ConvTranspose3d, | |||
| ConvTransposeRelu2d, | |||
| DeformableConv2d, | |||
| LocalConv2d, | |||
| RegionRestrictedConv, | |||
| ) | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||
| from .deformable_psroi_pooling import DeformablePSROIPooling | |||
| from .dropout import Dropout | |||
| from .elemwise import Elemwise | |||
| @@ -773,6 +773,15 @@ class ConvRelu2d(Conv2d): | |||
| return relu(self.calc_conv(inp, self.weight, self.bias)) | |||
| class ConvTransposeRelu2d(ConvTranspose2d): | |||
| r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :func:`~.relu`. | |||
| Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeRelu2d` using :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inp): | |||
| return relu(self.calc_conv_transpose2d(inp, self.weight, self.bias)) | |||
| class DeformableConv2d(_ConvNd): | |||
| r"""Deformable Convolution. | |||
| @@ -0,0 +1,62 @@ | |||
| from typing import Tuple, Union | |||
| from ..functional import relu | |||
| from .batchnorm import BatchNorm2d | |||
| from .conv import ConvTranspose2d | |||
| from .module import Module | |||
| class _ConvTransposeBnActivation2d(Module): | |||
| def __init__( | |||
| self, | |||
| in_channels: int, | |||
| out_channels: int, | |||
| kernel_size: Union[int, Tuple[int, int]], | |||
| stride: Union[int, Tuple[int, int]] = 1, | |||
| padding: Union[int, Tuple[int, int]] = 0, | |||
| output_padding: Union[int, Tuple[int, int]] = 0, | |||
| dilation: Union[int, Tuple[int, int]] = 1, | |||
| groups: int = 1, | |||
| bias: bool = True, | |||
| conv_mode: str = "cross_correlation", | |||
| compute_mode: str = "default", | |||
| eps=1e-5, | |||
| momentum=0.9, | |||
| affine=True, | |||
| track_running_stats=True, | |||
| **kwargs | |||
| ): | |||
| super().__init__(**kwargs) | |||
| self.conv_transpose2d = ConvTranspose2d( | |||
| in_channels, | |||
| out_channels, | |||
| kernel_size, | |||
| stride, | |||
| padding, | |||
| output_padding, | |||
| dilation, | |||
| groups, | |||
| bias, | |||
| conv_mode, | |||
| compute_mode, | |||
| **kwargs, | |||
| ) | |||
| self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) | |||
| class ConvTransposeBn2d(_ConvTransposeBnActivation2d): | |||
| r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d`. | |||
| Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBn2d` using:func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.bn(self.conv_transpose2d(inp)) | |||
| class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): | |||
| r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`. | |||
| Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBnRelu2d` using :func:`~.quantize.quantize_qat`. | |||
| """ | |||
| def forward(self, inp): | |||
| return relu(self.bn(self.conv_transpose2d(inp))) | |||
| @@ -1,7 +1,8 @@ | |||
| from .batch_matmul_activation import BatchMatMulActivation | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d | |||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||
| from .elemwise import Elemwise | |||
| from .linear import Linear | |||
| from .module import QATModule | |||
| @@ -59,8 +59,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): | |||
| def calc_conv_transpose2d_qat(self, inp): | |||
| w_qat = self.apply_quant_weight(self.weight) | |||
| b_qat = self.apply_quant_bias(self.bias, inp, w_qat) | |||
| conv = self.calc_conv_transpose2d(inp, w_qat, b_qat) | |||
| return conv | |||
| conv_transpose2d = self.calc_conv_transpose2d(inp, w_qat, b_qat) | |||
| return conv_transpose2d | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float.ConvTranspose2d): | |||
| @@ -88,3 +88,12 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(self.calc_conv_transpose2d_qat(inp)) | |||
| class ConvTransposeRelu2d(ConvTranspose2d): | |||
| r"""A :class:`~.QATModule` include :class:`~.module.ConvTranspose2d` and :func:`~.relu` with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(F.relu(self.calc_conv_transpose2d_qat(inp))) | |||
| @@ -0,0 +1,163 @@ | |||
| from ...functional import ones, relu, sqrt, sum, zeros | |||
| from .. import conv_transpose_bn as Float | |||
| from .module import QATModule | |||
| class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule): | |||
| def get_batch_mean_var(self, inp): | |||
| def _sum_channel(inp, axis=0, keepdims=True): | |||
| if isinstance(axis, int): | |||
| out = sum(inp, axis=axis, keepdims=keepdims) | |||
| elif isinstance(axis, tuple): | |||
| for idx, elem in enumerate(axis): | |||
| out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) | |||
| return out | |||
| sum1 = _sum_channel(inp, (0, 2, 3)) | |||
| sum2 = _sum_channel(inp ** 2, (0, 2, 3)) | |||
| reduce_size = inp.size / inp.shape[1] | |||
| batch_mean = sum1 / reduce_size | |||
| batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size | |||
| return batch_mean, batch_var | |||
| def fold_weight_bias(self, bn_mean, bn_var): | |||
| # get fold bn conv_transpose2d param | |||
| gamma = self.bn.weight | |||
| if gamma is None: | |||
| gamma = ones((self.bn.num_features), dtype="float32") | |||
| gamma = gamma.reshape(1, -1, 1, 1) | |||
| beta = self.bn.bias | |||
| if beta is None: | |||
| beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| if bn_mean is None: | |||
| bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| if bn_var is None: | |||
| bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") | |||
| conv_transpose2d_bias = self.conv_transpose2d.bias | |||
| if conv_transpose2d_bias is None: | |||
| conv_transpose2d_bias = zeros( | |||
| self.conv_transpose2d._infer_bias_shape(), dtype="float32" | |||
| ) | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| scale_factor = gamma * bn_istd | |||
| if self.conv_transpose2d.groups == 1: | |||
| w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1) | |||
| else: | |||
| w_fold = self.conv_transpose2d.weight * scale_factor.reshape( | |||
| self.conv_transpose2d.groups, -1, 1, 1, 1 | |||
| ) | |||
| w_fold = self.apply_quant_weight(w_fold) | |||
| b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd | |||
| return w_fold, b_fold | |||
| def update_running_mean_and_running_var( | |||
| self, bn_mean, bn_var, num_elements_per_channel | |||
| ): | |||
| # update running mean and running var. no grad, use unbiased bn var | |||
| bn_mean = bn_mean.detach() | |||
| bn_var = ( | |||
| bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1) | |||
| ) | |||
| exponential_average_factor = 1 - self.bn.momentum | |||
| self.bn.running_mean *= self.bn.momentum | |||
| self.bn.running_mean += exponential_average_factor * bn_mean | |||
| self.bn.running_var *= self.bn.momentum | |||
| self.bn.running_var += exponential_average_factor * bn_var | |||
| def calc_conv_transpose2d_bn_qat(self, inp, approx=True): | |||
| if self.training and not approx: | |||
| conv_transpose2d = self.conv_transpose2d(inp) | |||
| bn_mean, bn_var = self.get_batch_mean_var(conv_transpose2d) | |||
| num_elements_per_channel = conv_transpose2d.size / conv_transpose2d.shape[1] | |||
| self.update_running_mean_and_running_var( | |||
| bn_mean, bn_var, num_elements_per_channel | |||
| ) | |||
| else: | |||
| bn_mean, bn_var = self.bn.running_mean, self.bn.running_var | |||
| # get gamma and beta in BatchNorm | |||
| gamma = self.bn.weight | |||
| if gamma is None: | |||
| gamma = ones((self.bn.num_features), dtype="float32") | |||
| gamma = gamma.reshape(1, -1, 1, 1) | |||
| beta = self.bn.bias | |||
| if beta is None: | |||
| beta = zeros((self.bn.num_features), dtype="float32") | |||
| beta = beta.reshape(1, -1, 1, 1) | |||
| # conv_transpose2d_bias | |||
| conv_transpose2d_bias = self.conv_transpose2d.bias | |||
| if conv_transpose2d_bias is None: | |||
| conv_transpose2d_bias = zeros( | |||
| self.conv_transpose2d._infer_bias_shape(), dtype="float32" | |||
| ) | |||
| bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
| scale_factor = gamma * bn_istd | |||
| if self.conv_transpose2d.groups == 1: | |||
| w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1) | |||
| else: | |||
| w_fold = self.conv_transpose2d.weight * scale_factor.reshape( | |||
| self.conv_transpose2d.groups, 1, -1, 1, 1 | |||
| ) | |||
| b_fold = None | |||
| if not (self.training and approx): | |||
| b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd | |||
| w_qat = self.apply_quant_weight(w_fold) | |||
| b_qat = self.apply_quant_bias(b_fold, inp, w_qat) | |||
| conv_transpose2d = self.conv_transpose2d.calc_conv_transpose2d( | |||
| inp, w_qat, b_qat | |||
| ) | |||
| if not (self.training and approx): | |||
| return conv_transpose2d | |||
| # rescale conv_transpose2d to get original conv_transpose2d output | |||
| orig_conv_transpose2d = conv_transpose2d / scale_factor.reshape(1, -1, 1, 1) | |||
| if self.conv_transpose2d.bias is not None: | |||
| orig_conv_transpose2d = orig_conv_transpose2d + self.conv_transpose2d.bias | |||
| # calculate batch norm | |||
| conv_transpose2d = self.bn(orig_conv_transpose2d) | |||
| return conv_transpose2d | |||
| @classmethod | |||
| def from_float_module(cls, float_module: Float._ConvTransposeBnActivation2d): | |||
| qat_module = cls( | |||
| float_module.conv_transpose2d.in_channels, | |||
| float_module.conv_transpose2d.out_channels, | |||
| float_module.conv_transpose2d.kernel_size, | |||
| float_module.conv_transpose2d.stride, | |||
| float_module.conv_transpose2d.padding, | |||
| float_module.conv_transpose2d.output_padding, | |||
| float_module.conv_transpose2d.dilation, | |||
| float_module.conv_transpose2d.groups, | |||
| float_module.conv_transpose2d.bias is not None, | |||
| float_module.conv_transpose2d.conv_mode, | |||
| float_module.conv_transpose2d.compute_mode, | |||
| name=float_module.name, | |||
| ) | |||
| qat_module.conv_transpose2d.weight = float_module.conv_transpose2d.weight | |||
| qat_module.conv_transpose2d.bias = float_module.conv_transpose2d.bias | |||
| qat_module.bn = float_module.bn | |||
| return qat_module | |||
| class ConvTransposeBn2d(_ConvTransposeBnActivation2d): | |||
| r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d` with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(self.calc_conv_transpose2d_bn_qat(inp)) | |||
| class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): | |||
| r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support. | |||
| Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. | |||
| """ | |||
| def forward(self, inp): | |||
| return self.apply_quant_activation(relu(self.calc_conv_transpose2d_bn_qat(inp))) | |||
| @@ -1,7 +1,8 @@ | |||
| from .batch_matmul_activation import BatchMatMulActivation | |||
| from .concat import Concat | |||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d | |||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d | |||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
| from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||
| from .elemwise import Elemwise | |||
| from .linear import Linear | |||
| from .module import QuantizedModule | |||
| @@ -178,7 +178,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| output_dtype = qat_module.get_activation_dtype() | |||
| qconv = cls( | |||
| qconv_transpose2d = cls( | |||
| qat_module.in_channels, | |||
| qat_module.out_channels, | |||
| qat_module.kernel_size, | |||
| @@ -194,15 +194,19 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
| name=qat_module.name, | |||
| ) | |||
| weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | |||
| qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name) | |||
| qconv.bias = ( | |||
| qconv_transpose2d.weight = Parameter( | |||
| weight.numpy(), name=qat_module.weight.name | |||
| ) | |||
| qconv_transpose2d.bias = ( | |||
| Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) | |||
| if qat_module.bias is not None | |||
| else None | |||
| ) | |||
| return qconv | |||
| return qconv_transpose2d | |||
| def calc_conv_transpose2d_quantized(self, inp, nonlinear_mode): | |||
| assert nonlinear_mode == "identity", "nonlinear_mode shoule be 'identity'" | |||
| def calc_conv_transpose2d_quantized(self, inp): | |||
| if self.bias is not None: | |||
| inp_scale = dtype.get_scale(inp.dtype) | |||
| w_scale = dtype.get_scale(self.weight.dtype) | |||
| @@ -225,4 +229,11 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
| ) | |||
| def forward(self, inp): | |||
| return self.calc_conv_transpose2d_quantized(inp) | |||
| return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity") | |||
| class ConvTransposeRelu2d(ConvTranspose2d): | |||
| r"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`.""" | |||
| def forward(self, inp): | |||
| return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu") | |||
| @@ -0,0 +1,53 @@ | |||
| from ...tensor import Parameter | |||
| from ..qat import conv_transpose_bn as QAT | |||
| from .conv import ConvTranspose2d | |||
| class _ConvTransposeBnActivation2d(ConvTranspose2d): | |||
| r"""Applies a 2D deconvolution over a quantized input tensor, used for inference only. | |||
| """ | |||
| @classmethod | |||
| def from_qat_module(cls, qat_module: QAT._ConvTransposeBnActivation2d): | |||
| r""" | |||
| Return a :class:`~.QuantizedModule` instance converted from a | |||
| :class:`~.QATModule` instance. | |||
| """ | |||
| output_dtype = qat_module.get_activation_dtype() | |||
| qconv_transpose2d = cls( | |||
| qat_module.conv_transpose2d.in_channels, | |||
| qat_module.conv_transpose2d.out_channels, | |||
| qat_module.conv_transpose2d.kernel_size, | |||
| qat_module.conv_transpose2d.stride, | |||
| qat_module.conv_transpose2d.padding, | |||
| qat_module.conv_transpose2d.output_padding, | |||
| qat_module.conv_transpose2d.dilation, | |||
| qat_module.conv_transpose2d.groups, | |||
| dtype=output_dtype, | |||
| name=qat_module.name, | |||
| ) | |||
| w_fold, b_fold = qat_module.fold_weight_bias( | |||
| qat_module.bn.running_mean, qat_module.bn.running_var | |||
| ) | |||
| weight = w_fold.astype(qat_module.get_weight_dtype()) | |||
| qconv_transpose2d.weight = Parameter( | |||
| weight.numpy(), name=qat_module.conv_transpose2d.weight.name | |||
| ) | |||
| qconv_transpose2d.bias = Parameter(b_fold.numpy()) | |||
| if qat_module.conv_transpose2d.bias is not None: | |||
| qconv_transpose2d.bias.name = qat_module.conv_transpose2d.bias.name | |||
| return qconv_transpose2d | |||
| class ConvTransposeBn2d(_ConvTransposeBnActivation2d): | |||
| r"""Quantized version of :class:`~.qat.ConvTransposeBn2d`.""" | |||
| def forward(self, inp): | |||
| return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity") | |||
| class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): | |||
| r"""Quantized version of :class:`~.qat.ConvTransposeBnRelu2d`.""" | |||
| def forward(self, inp): | |||
| return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu") | |||
| @@ -1,48 +1,70 @@ | |||
| from copy import deepcopy | |||
| from ..functional import ones, sqrt, zeros | |||
| from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU | |||
| from ..module import ( | |||
| BatchNorm2d, | |||
| Conv2d, | |||
| ConvBn2d, | |||
| ConvBnRelu2d, | |||
| ConvRelu2d, | |||
| ConvTranspose2d, | |||
| ConvTransposeBn2d, | |||
| ConvTransposeBnRelu2d, | |||
| ConvTransposeRelu2d, | |||
| ReLU, | |||
| ) | |||
| from ..tensor import Parameter | |||
| _MAP_TO_FUSED_MODULE = { | |||
| (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d, | |||
| (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d, | |||
| (ConvTranspose2d, BatchNorm2d, ReLU, False): ConvTransposeRelu2d, | |||
| (ConvTranspose2d, BatchNorm2d, ReLU, True): ConvTransposeBnRelu2d, | |||
| (Conv2d, BatchNorm2d, False): Conv2d, | |||
| (Conv2d, BatchNorm2d, True): ConvBn2d, | |||
| (Conv2d, ReLU): ConvRelu2d, | |||
| (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d, | |||
| (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d, | |||
| (ConvTranspose2d, ReLU): ConvTransposeRelu2d, | |||
| } | |||
| def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5): | |||
| # get fold bn conv param | |||
| def fold_weight_bias( | |||
| weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False | |||
| ): | |||
| shape = (1, -1, 1, 1) | |||
| if transpose: | |||
| shape = (-1, 1, 1, 1) | |||
| kernel_shape = weight.shape | |||
| if len(kernel_shape) == 5: | |||
| groups, num_features = kernel_shape[0], kernel_shape[1] | |||
| else: | |||
| groups, num_features = 1, kernel_shape[0] | |||
| out_channels = groups * num_features | |||
| if gamma is None: | |||
| gamma = ones((num_features), dtype="float32") | |||
| gamma = ones((out_channels,), dtype="float32") | |||
| gamma = gamma.reshape(1, -1, 1, 1) | |||
| if beta is None: | |||
| beta = zeros((num_features), dtype="float32") | |||
| beta = zeros((out_channels,), dtype="float32") | |||
| beta = beta.reshape(1, -1, 1, 1) | |||
| if bn_mean is None: | |||
| bn_mean = zeros((1, num_features, 1, 1), dtype="float32") | |||
| bn_mean = zeros((1, out_channels, 1, 1), dtype="float32") | |||
| if bn_var is None: | |||
| bn_var = ones((1, num_features, 1, 1), dtype="float32") | |||
| bn_var = ones((1, out_channels, 1, 1), dtype="float32") | |||
| if bias is None: | |||
| bias = zeros((1, num_features, 1, 1), dtype="float32") | |||
| bias = zeros((1, out_channels, 1, 1), dtype="float32") | |||
| bn_istd = 1.0 / sqrt(bn_var + eps) | |||
| scale_factor = gamma * bn_istd | |||
| if groups == 1: | |||
| w_fold = weight * scale_factor.reshape(-1, 1, 1, 1) | |||
| w_fold = weight * scale_factor.reshape(*shape) | |||
| else: | |||
| w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1) | |||
| w_fold = weight * scale_factor.reshape(groups, *shape) | |||
| b_fold = beta + gamma * (bias - bn_mean) * bn_istd | |||
| return w_fold, b_fold | |||
| @@ -84,3 +106,55 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): | |||
| module.bn = deepcopy(bn) | |||
| new_conv.training = conv.training | |||
| return module | |||
| def fuse_conv_transpose2d_bn_relu_module( | |||
| conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU | |||
| ): | |||
| module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m]) | |||
| if bn: | |||
| assert ( | |||
| conv_transpose2d.training == bn.training | |||
| ), "ConvTranspose2d and BN both must be in the same mode (train or eval)." | |||
| assert ( | |||
| bn.num_features == conv_transpose2d.out_channels | |||
| ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d" | |||
| module_key = module_key + (conv_transpose2d.training,) | |||
| module = _MAP_TO_FUSED_MODULE[module_key]( | |||
| in_channels=conv_transpose2d.in_channels, | |||
| out_channels=conv_transpose2d.out_channels, | |||
| kernel_size=conv_transpose2d.kernel_size, | |||
| stride=conv_transpose2d.stride, | |||
| padding=conv_transpose2d.padding, | |||
| output_padding=conv_transpose2d.output_padding, | |||
| dilation=conv_transpose2d.dilation, | |||
| groups=conv_transpose2d.groups, | |||
| bias=conv_transpose2d.bias is not None, | |||
| conv_mode=conv_transpose2d.conv_mode, | |||
| compute_mode=conv_transpose2d.compute_mode, | |||
| name=conv_transpose2d.name, | |||
| ) | |||
| new_conv_transpose2d = ( | |||
| module | |||
| if bn is None or not conv_transpose2d.training | |||
| else module.conv_transpose2d | |||
| ) | |||
| weight, bias = conv_transpose2d.weight, conv_transpose2d.bias | |||
| if not conv_transpose2d.training and bn is not None: | |||
| weight, bias = fold_weight_bias( | |||
| weight, | |||
| bias, | |||
| bn.weight, | |||
| bn.bias, | |||
| bn.running_mean, | |||
| bn.running_var, | |||
| bn.eps, | |||
| transpose=False, | |||
| ) | |||
| new_conv_transpose2d.weight = Parameter(weight) | |||
| if bias is not None: | |||
| new_conv_transpose2d.bias = Parameter(bias) | |||
| if bn is not None and conv_transpose2d.training: | |||
| module.bn = deepcopy(bn) | |||
| new_conv_transpose2d.training = conv_transpose2d.training | |||
| return module | |||
| @@ -5,7 +5,9 @@ import numpy as np | |||
| import pytest | |||
| import megengine.utils.comp_graph_tools as cgtools | |||
| from megengine import jit, tensor | |||
| from megengine import jit | |||
| from megengine import module as M | |||
| from megengine import tensor | |||
| from megengine.device import get_device_count | |||
| from megengine.functional import expand_dims | |||
| from megengine.module import ( | |||
| @@ -14,6 +16,8 @@ from megengine.module import ( | |||
| ConvBn2d, | |||
| ConvRelu2d, | |||
| ConvTranspose2d, | |||
| ConvTransposeBn2d, | |||
| ConvTransposeRelu2d, | |||
| DequantStub, | |||
| Module, | |||
| QuantStub, | |||
| @@ -34,6 +38,49 @@ def test_qat_convbn2d(): | |||
| module = ConvBn2d( | |||
| in_channels, out_channels, kernel_size, groups=groups, bias=bias | |||
| ) | |||
| M.init.normal_(module.bn.weight) | |||
| M.init.normal_(module.bn.bias) | |||
| module.train() | |||
| qat_module = quantize_qat(module, inplace=False) | |||
| disable_fake_quant(qat_module) | |||
| inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||
| normal_outputs = module(inputs) | |||
| qat_outputs = qat_module(inputs) | |||
| np.testing.assert_allclose( | |||
| normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 | |||
| ) | |||
| np.testing.assert_allclose( | |||
| module.bn.running_mean.numpy(), | |||
| qat_module.bn.running_mean.numpy(), | |||
| atol=5e-8, | |||
| ) | |||
| np.testing.assert_allclose( | |||
| module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7, | |||
| ) | |||
| module.eval() | |||
| normal_outputs = module(inputs) | |||
| qat_module.eval() | |||
| qat_outputs = qat_module(inputs) | |||
| np.testing.assert_allclose( | |||
| normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 | |||
| ) | |||
| def test_qat_convtransposebn2d(): | |||
| in_channels = 32 | |||
| out_channels = 64 | |||
| kernel_size = 3 | |||
| for groups, bias in product([1, 4], [True, False]): | |||
| module = ConvTransposeBn2d( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| kernel_size=kernel_size, | |||
| output_padding=0, | |||
| groups=groups, | |||
| bias=bias, | |||
| ) | |||
| M.init.normal_(module.bn.weight) | |||
| M.init.normal_(module.bn.bias) | |||
| module.train() | |||
| qat_module = quantize_qat(module, inplace=False) | |||
| disable_fake_quant(qat_module) | |||
| @@ -235,10 +282,14 @@ def test_qat_conv_transpose2d(): | |||
| self.conv = ConvTranspose2d( | |||
| in_channels, out_channels, kernel_size, bias=bias | |||
| ) | |||
| self.conv_transpose2d_relu = ConvTransposeRelu2d( | |||
| out_channels, in_channels, kernel_size, bias=bias | |||
| ) | |||
| def forward(self, inp): | |||
| out = self.quant(inp) | |||
| out = self.conv(out) | |||
| out = self.conv_transpose2d_relu(out) | |||
| out = self.dequant(out) | |||
| return out | |||
| @@ -250,10 +301,14 @@ def test_qat_conv_transpose2d(): | |||
| disable_fake_quant(qat_net) | |||
| normal_outputs = net(inputs) | |||
| qat_outputs = qat_net(inputs) | |||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
| np.testing.assert_allclose( | |||
| normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 | |||
| ) | |||
| net.eval() | |||
| normal_outputs = net(inputs) | |||
| qat_net.eval() | |||
| qat_outputs = qat_net(inputs) | |||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
| np.testing.assert_allclose( | |||
| normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 | |||
| ) | |||