| @@ -12,11 +12,13 @@ from .conv import ( | |||||
| ConvRelu2d, | ConvRelu2d, | ||||
| ConvTranspose2d, | ConvTranspose2d, | ||||
| ConvTranspose3d, | ConvTranspose3d, | ||||
| ConvTransposeRelu2d, | |||||
| DeformableConv2d, | DeformableConv2d, | ||||
| LocalConv2d, | LocalConv2d, | ||||
| RegionRestrictedConv, | RegionRestrictedConv, | ||||
| ) | ) | ||||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | from .conv_bn import ConvBn2d, ConvBnRelu2d | ||||
| from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||||
| from .deformable_psroi_pooling import DeformablePSROIPooling | from .deformable_psroi_pooling import DeformablePSROIPooling | ||||
| from .dropout import Dropout | from .dropout import Dropout | ||||
| from .elemwise import Elemwise | from .elemwise import Elemwise | ||||
| @@ -773,6 +773,15 @@ class ConvRelu2d(Conv2d): | |||||
| return relu(self.calc_conv(inp, self.weight, self.bias)) | return relu(self.calc_conv(inp, self.weight, self.bias)) | ||||
class ConvTransposeRelu2d(ConvTranspose2d):
    r"""Fused :class:`~.Module` applying :class:`~.module.ConvTranspose2d` followed by :func:`~.relu`.

    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeRelu2d`
    using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        # Transposed convolution with this module's own weight and bias,
        # then rectification of the result.
        deconv_out = self.calc_conv_transpose2d(inp, self.weight, self.bias)
        return relu(deconv_out)
| class DeformableConv2d(_ConvNd): | class DeformableConv2d(_ConvNd): | ||||
| r"""Deformable Convolution. | r"""Deformable Convolution. | ||||
| @@ -0,0 +1,62 @@ | |||||
| from typing import Tuple, Union | |||||
| from ..functional import relu | |||||
| from .batchnorm import BatchNorm2d | |||||
| from .conv import ConvTranspose2d | |||||
| from .module import Module | |||||
class _ConvTransposeBnActivation2d(Module):
    r"""Container holding a :class:`~.module.ConvTranspose2d` and a :class:`~.module.BatchNorm2d`.

    This base class only builds the two sub-modules; subclasses provide
    ``forward``. The convolution arguments are forwarded to ``ConvTranspose2d``
    and ``eps``, ``momentum``, ``affine`` and ``track_running_stats`` are
    forwarded to ``BatchNorm2d``. Extra keyword arguments go to both the
    :class:`~.Module` base initializer and the inner ``ConvTranspose2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        output_padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Transposed convolution; receives the same extra kwargs as the base.
        deconv = ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
            **kwargs,
        )
        self.conv_transpose2d = deconv

        # Batch normalization over the convolution's output channels.
        norm = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
        self.bn = norm
class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
    r"""Fused :class:`~.Module` applying :class:`~.module.ConvTranspose2d` followed by :class:`~.module.BatchNorm2d`.

    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBn2d`
    using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        # Transposed convolution first, batch normalization on its output.
        deconv_out = self.conv_transpose2d(inp)
        return self.bn(deconv_out)
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
    r"""Fused :class:`~.Module` applying :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` in sequence.

    Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBnRelu2d`
    using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        # Transposed convolution -> batch normalization -> ReLU.
        deconv_out = self.conv_transpose2d(inp)
        normalized = self.bn(deconv_out)
        return relu(normalized)
| @@ -1,7 +1,8 @@ | |||||
| from .batch_matmul_activation import BatchMatMulActivation | from .batch_matmul_activation import BatchMatMulActivation | ||||
| from .concat import Concat | from .concat import Concat | ||||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d | |||||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d | |||||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | from .conv_bn import ConvBn2d, ConvBnRelu2d | ||||
| from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||||
| from .elemwise import Elemwise | from .elemwise import Elemwise | ||||
| from .linear import Linear | from .linear import Linear | ||||
| from .module import QATModule | from .module import QATModule | ||||
| @@ -59,8 +59,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): | |||||
| def calc_conv_transpose2d_qat(self, inp): | def calc_conv_transpose2d_qat(self, inp): | ||||
| w_qat = self.apply_quant_weight(self.weight) | w_qat = self.apply_quant_weight(self.weight) | ||||
| b_qat = self.apply_quant_bias(self.bias, inp, w_qat) | b_qat = self.apply_quant_bias(self.bias, inp, w_qat) | ||||
| conv = self.calc_conv_transpose2d(inp, w_qat, b_qat) | |||||
| return conv | |||||
| conv_transpose2d = self.calc_conv_transpose2d(inp, w_qat, b_qat) | |||||
| return conv_transpose2d | |||||
| @classmethod | @classmethod | ||||
| def from_float_module(cls, float_module: Float.ConvTranspose2d): | def from_float_module(cls, float_module: Float.ConvTranspose2d): | ||||
| @@ -88,3 +88,12 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): | |||||
| def forward(self, inp): | def forward(self, inp): | ||||
| return self.apply_quant_activation(self.calc_conv_transpose2d_qat(inp)) | return self.apply_quant_activation(self.calc_conv_transpose2d_qat(inp)) | ||||
class ConvTransposeRelu2d(ConvTranspose2d):
    r"""A :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and :func:`~.relu` with QAT support.

    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        # ReLU is applied to the QAT transposed-convolution output before the
        # activation quantization step.
        rectified = F.relu(self.calc_conv_transpose2d_qat(inp))
        return self.apply_quant_activation(rectified)
| @@ -0,0 +1,163 @@ | |||||
| from ...functional import ones, relu, sqrt, sum, zeros | |||||
| from .. import conv_transpose_bn as Float | |||||
| from .module import QATModule | |||||
class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule):
    r"""QAT base class for fused ``ConvTranspose2d`` + ``BatchNorm2d`` modules.

    The BatchNorm scale and shift are folded into the transposed-convolution
    weight and bias so that the weight/bias quantization hooks operate on the
    folded values.
    """

    def get_batch_mean_var(self, inp):
        r"""Return the per-channel mean and (biased) variance of ``inp``.

        The statistics are reduced over axes 0, 2 and 3 with ``keepdims=True``.
        """

        def _sum_channel(inp, axis=0, keepdims=True):
            # Reduce one axis at a time when a tuple of axes is given.
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.size / inp.shape[1]
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def _fold_scale_into_weight(self, scale_factor):
        r"""Return the transposed-convolution weight multiplied by ``scale_factor``.

        The scale is broadcast along weight axis 1 (axis 2 of the 5-D grouped
        weight). Both :meth:`fold_weight_bias` and
        :meth:`calc_conv_transpose2d_bn_qat` go through this helper so the
        conversion path and the QAT forward path always scale the same axis.
        """
        deconv = self.conv_transpose2d
        if deconv.groups == 1:
            return deconv.weight * scale_factor.reshape(1, -1, 1, 1)
        return deconv.weight * scale_factor.reshape(deconv.groups, 1, -1, 1, 1)

    def fold_weight_bias(self, bn_mean, bn_var):
        r"""Fold BatchNorm into the transposed-convolution weight and bias.

        Args:
            bn_mean: BatchNorm mean; ``None`` is treated as zeros.
            bn_var: BatchNorm variance; ``None`` is treated as ones.

        Returns:
            ``(w_fold, b_fold)`` where ``w_fold`` has been passed through
            :meth:`apply_quant_weight` and
            ``b_fold = beta + gamma * (bias - bn_mean) / sqrt(bn_var + eps)``.
        """
        # get fold bn conv_transpose2d param
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_transpose2d_bias = self.conv_transpose2d.bias
        if conv_transpose2d_bias is None:
            conv_transpose2d_bias = zeros(
                self.conv_transpose2d._infer_bias_shape(), dtype="float32"
            )

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd

        # Use the same broadcast shape as the QAT forward path; the previous
        # (-1, 1, 1, 1) / (groups, -1, 1, 1, 1) reshape scaled a different
        # weight axis than calc_conv_transpose2d_bn_qat does.
        w_fold = self._fold_scale_into_weight(scale_factor)
        w_fold = self.apply_quant_weight(w_fold)

        b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd
        return w_fold, b_fold

    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        r"""Update ``bn.running_mean`` / ``bn.running_var`` in place.

        ``bn_var`` is rescaled by ``n / (n - 1)`` (unbiased estimate) and both
        statistics are detached before being blended in with ``bn.momentum``.
        """
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = bn_mean.detach()
        bn_var = (
            bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        self.bn.running_mean *= self.bn.momentum
        self.bn.running_mean += exponential_average_factor * bn_mean
        self.bn.running_var *= self.bn.momentum
        self.bn.running_var += exponential_average_factor * bn_var

    def calc_conv_transpose2d_bn_qat(self, inp, approx=True):
        r"""Compute the fused transposed convolution + BatchNorm with QAT hooks.

        Args:
            inp: input tensor.
            approx: when training with ``approx=True`` the weight is folded
                with the running statistics, the result is rescaled back and
                ``self.bn`` is applied on top. When training with
                ``approx=False`` batch statistics are computed from an extra
                un-folded forward pass and used for folding.

        Returns:
            The fused output, before activation quantization.
        """
        if self.training and not approx:
            conv_transpose2d = self.conv_transpose2d(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv_transpose2d)
            num_elements_per_channel = conv_transpose2d.size / conv_transpose2d.shape[1]
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)

        # conv_transpose2d_bias
        conv_transpose2d_bias = self.conv_transpose2d.bias
        if conv_transpose2d_bias is None:
            conv_transpose2d_bias = zeros(
                self.conv_transpose2d._infer_bias_shape(), dtype="float32"
            )

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        w_fold = self._fold_scale_into_weight(scale_factor)

        # The folded bias is only used when the folded convolution is the
        # final result (eval, or training with approx=False).
        b_fold = None
        if not (self.training and approx):
            b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        b_qat = self.apply_quant_bias(b_fold, inp, w_qat)
        conv_transpose2d = self.conv_transpose2d.calc_conv_transpose2d(
            inp, w_qat, b_qat
        )
        if not (self.training and approx):
            return conv_transpose2d

        # rescale conv_transpose2d to get original conv_transpose2d output
        orig_conv_transpose2d = conv_transpose2d / scale_factor.reshape(1, -1, 1, 1)
        if self.conv_transpose2d.bias is not None:
            orig_conv_transpose2d = orig_conv_transpose2d + self.conv_transpose2d.bias
        # calculate batch norm
        conv_transpose2d = self.bn(orig_conv_transpose2d)
        return conv_transpose2d

    @classmethod
    def from_float_module(cls, float_module: Float._ConvTransposeBnActivation2d):
        r"""Return a QAT module built from ``float_module``.

        The convolution hyper-parameters are copied; the weight, bias and
        ``bn`` sub-module of ``float_module`` are assigned to the new module
        (not copied).
        """
        qat_module = cls(
            float_module.conv_transpose2d.in_channels,
            float_module.conv_transpose2d.out_channels,
            float_module.conv_transpose2d.kernel_size,
            float_module.conv_transpose2d.stride,
            float_module.conv_transpose2d.padding,
            float_module.conv_transpose2d.output_padding,
            float_module.conv_transpose2d.dilation,
            float_module.conv_transpose2d.groups,
            float_module.conv_transpose2d.bias is not None,
            float_module.conv_transpose2d.conv_mode,
            float_module.conv_transpose2d.compute_mode,
            name=float_module.name,
        )
        qat_module.conv_transpose2d.weight = float_module.conv_transpose2d.weight
        qat_module.conv_transpose2d.bias = float_module.conv_transpose2d.bias
        qat_module.bn = float_module.bn
        return qat_module
class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
    r"""Fused :class:`~.QATModule` combining :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d` with QAT support.

    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        # Fused transposed convolution + BN, then activation quantization.
        fused = self.calc_conv_transpose2d_bn_qat(inp)
        return self.apply_quant_activation(fused)
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
    r"""Fused :class:`~.QATModule` combining :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support.

    Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`.
    """

    def forward(self, inp):
        # Fused transposed convolution + BN, ReLU, then activation quantization.
        fused = self.calc_conv_transpose2d_bn_qat(inp)
        rectified = relu(fused)
        return self.apply_quant_activation(rectified)
| @@ -1,7 +1,8 @@ | |||||
| from .batch_matmul_activation import BatchMatMulActivation | from .batch_matmul_activation import BatchMatMulActivation | ||||
| from .concat import Concat | from .concat import Concat | ||||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d | |||||
| from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d | |||||
| from .conv_bn import ConvBn2d, ConvBnRelu2d | from .conv_bn import ConvBn2d, ConvBnRelu2d | ||||
| from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||||
| from .elemwise import Elemwise | from .elemwise import Elemwise | ||||
| from .linear import Linear | from .linear import Linear | ||||
| from .module import QuantizedModule | from .module import QuantizedModule | ||||
| @@ -178,7 +178,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||||
| :class:`~.QATModule` instance. | :class:`~.QATModule` instance. | ||||
| """ | """ | ||||
| output_dtype = qat_module.get_activation_dtype() | output_dtype = qat_module.get_activation_dtype() | ||||
| qconv = cls( | |||||
| qconv_transpose2d = cls( | |||||
| qat_module.in_channels, | qat_module.in_channels, | ||||
| qat_module.out_channels, | qat_module.out_channels, | ||||
| qat_module.kernel_size, | qat_module.kernel_size, | ||||
| @@ -194,15 +194,19 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||||
| name=qat_module.name, | name=qat_module.name, | ||||
| ) | ) | ||||
| weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | ||||
| qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name) | |||||
| qconv.bias = ( | |||||
| qconv_transpose2d.weight = Parameter( | |||||
| weight.numpy(), name=qat_module.weight.name | |||||
| ) | |||||
| qconv_transpose2d.bias = ( | |||||
| Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) | Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) | ||||
| if qat_module.bias is not None | if qat_module.bias is not None | ||||
| else None | else None | ||||
| ) | ) | ||||
| return qconv | |||||
| return qconv_transpose2d | |||||
| def calc_conv_transpose2d_quantized(self, inp, nonlinear_mode): | |||||
| assert nonlinear_mode == "identity", "nonlinear_mode shoule be 'identity'" | |||||
| def calc_conv_transpose2d_quantized(self, inp): | |||||
| if self.bias is not None: | if self.bias is not None: | ||||
| inp_scale = dtype.get_scale(inp.dtype) | inp_scale = dtype.get_scale(inp.dtype) | ||||
| w_scale = dtype.get_scale(self.weight.dtype) | w_scale = dtype.get_scale(self.weight.dtype) | ||||
| @@ -225,4 +229,11 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||||
| ) | ) | ||||
| def forward(self, inp): | def forward(self, inp): | ||||
| return self.calc_conv_transpose2d_quantized(inp) | |||||
| return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity") | |||||
class ConvTransposeRelu2d(ConvTranspose2d):
    r"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`."""

    def forward(self, inp):
        # NOTE(review): confirm calc_conv_transpose2d_quantized accepts
        # nonlinear_mode="relu"; the definition seen during review asserts
        # that the mode is "identity".
        mode = "relu"
        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode=mode)
| @@ -0,0 +1,53 @@ | |||||
| from ...tensor import Parameter | |||||
| from ..qat import conv_transpose_bn as QAT | |||||
| from .conv import ConvTranspose2d | |||||
class _ConvTransposeBnActivation2d(ConvTranspose2d):
    r"""Applies a 2D deconvolution over a quantized input tensor, used for inference only."""

    @classmethod
    def from_qat_module(cls, qat_module: QAT._ConvTransposeBnActivation2d):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.

        The BatchNorm running statistics of ``qat_module`` are folded into the
        weight and bias through ``qat_module.fold_weight_bias``.
        """
        output_dtype = qat_module.get_activation_dtype()
        qat_deconv = qat_module.conv_transpose2d
        qconv_transpose2d = cls(
            qat_deconv.in_channels,
            qat_deconv.out_channels,
            qat_deconv.kernel_size,
            qat_deconv.stride,
            qat_deconv.padding,
            qat_deconv.output_padding,
            qat_deconv.dilation,
            qat_deconv.groups,
            dtype=output_dtype,
            name=qat_module.name,
        )

        # Fold BN (running statistics) into the convolution parameters.
        running_mean = qat_module.bn.running_mean
        running_var = qat_module.bn.running_var
        w_fold, b_fold = qat_module.fold_weight_bias(running_mean, running_var)

        # The folded weight is cast to the QAT weight dtype before storing.
        weight = w_fold.astype(qat_module.get_weight_dtype())
        qconv_transpose2d.weight = Parameter(
            weight.numpy(), name=qat_module.conv_transpose2d.weight.name
        )

        # A folded bias always exists; it keeps the original bias name only
        # when the QAT convolution had a bias.
        qconv_transpose2d.bias = Parameter(b_fold.numpy())
        if qat_module.conv_transpose2d.bias is not None:
            qconv_transpose2d.bias.name = qat_module.conv_transpose2d.bias.name
        return qconv_transpose2d
class ConvTransposeBn2d(_ConvTransposeBnActivation2d):
    r"""Quantized version of :class:`~.qat.ConvTransposeBn2d`."""

    def forward(self, inp):
        # No extra non-linearity: only the quantized transposed convolution.
        mode = "identity"
        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode=mode)
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d):
    r"""Quantized version of :class:`~.qat.ConvTransposeBnRelu2d`."""

    def forward(self, inp):
        # NOTE(review): confirm calc_conv_transpose2d_quantized supports
        # nonlinear_mode="relu" — TODO verify against its definition.
        mode = "relu"
        return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode=mode)
| @@ -1,48 +1,70 @@ | |||||
| from copy import deepcopy | from copy import deepcopy | ||||
| from ..functional import ones, sqrt, zeros | from ..functional import ones, sqrt, zeros | ||||
| from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU | |||||
| from ..module import ( | |||||
| BatchNorm2d, | |||||
| Conv2d, | |||||
| ConvBn2d, | |||||
| ConvBnRelu2d, | |||||
| ConvRelu2d, | |||||
| ConvTranspose2d, | |||||
| ConvTransposeBn2d, | |||||
| ConvTransposeBnRelu2d, | |||||
| ConvTransposeRelu2d, | |||||
| ReLU, | |||||
| ) | |||||
| from ..tensor import Parameter | from ..tensor import Parameter | ||||
# Lookup table from the module types being fused to the fused module class.
# A key holds the types of the modules that are present, followed by the
# convolution's ``training`` flag when a BatchNorm2d is part of the fusion.
# With the flag False the mapped class has no BN sub-module; with the flag
# True the mapped class keeps one.
_MAP_TO_FUSED_MODULE = {
    # conv + bn + relu
    (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
    (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
    (ConvTranspose2d, BatchNorm2d, ReLU, False): ConvTransposeRelu2d,
    (ConvTranspose2d, BatchNorm2d, ReLU, True): ConvTransposeBnRelu2d,
    # conv + bn, conv + relu
    (Conv2d, BatchNorm2d, False): Conv2d,
    (Conv2d, BatchNorm2d, True): ConvBn2d,
    (Conv2d, ReLU): ConvRelu2d,
    # conv_transpose + bn, conv_transpose + relu
    (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d,
    (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d,
    (ConvTranspose2d, ReLU): ConvTransposeRelu2d,
}
def fold_weight_bias(
    weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
):
    r"""Fold BatchNorm parameters into a convolution weight and bias.

    Computes ``w_fold = weight * gamma / sqrt(bn_var + eps)`` and
    ``b_fold = beta + gamma * (bias - bn_mean) / sqrt(bn_var + eps)``.

    Args:
        weight: 4-D kernel, or 5-D kernel whose leading axis is the group count.
        bias: convolution bias, or ``None`` (treated as zeros).
        gamma: BatchNorm scale, or ``None`` (treated as ones).
        beta: BatchNorm shift, or ``None`` (treated as zeros).
        bn_mean: BatchNorm mean, or ``None`` (treated as zeros).
        bn_var: BatchNorm variance, or ``None`` (treated as ones).
        eps: value added to the variance before the square root.
        transpose: selects the weight axis the per-channel scale is broadcast
            along. ``False`` (default) uses axis 1 (axis 2 of a 5-D kernel),
            ``True`` uses axis 0 (axis 1 of a 5-D kernel).

    Returns:
        ``(w_fold, b_fold)``; ``b_fold`` has shape ``(1, C, 1, 1)``.

    NOTE(review): the ConvTranspose2d fuse helper in this file passes
    ``transpose=False``; confirm that every caller holding weights whose
    channel axis is 0 passes ``transpose=True``.
    """
    kernel_shape = weight.shape
    grouped = len(kernel_shape) == 5
    groups = kernel_shape[0] if grouped else 1

    if transpose:
        shape = (-1, 1, 1, 1)
        channel_axis = 0
    else:
        shape = (1, -1, 1, 1)
        channel_axis = 1
    if grouped:
        # The leading axis of a 5-D kernel is the group axis.
        channel_axis += 1

    # Read the channel count from the axis the scale is applied to, so the
    # placeholders created for missing arguments always broadcast with it.
    num_features = kernel_shape[channel_axis]
    out_channels = groups * num_features

    if gamma is None:
        gamma = ones((out_channels,), dtype="float32")
    gamma = gamma.reshape(1, -1, 1, 1)
    if beta is None:
        beta = zeros((out_channels,), dtype="float32")
    beta = beta.reshape(1, -1, 1, 1)

    if bn_mean is None:
        bn_mean = zeros((1, out_channels, 1, 1), dtype="float32")
    if bn_var is None:
        bn_var = ones((1, out_channels, 1, 1), dtype="float32")

    if bias is None:
        bias = zeros((1, out_channels, 1, 1), dtype="float32")

    bn_istd = 1.0 / sqrt(bn_var + eps)
    scale_factor = gamma * bn_istd

    if groups == 1:
        w_fold = weight * scale_factor.reshape(*shape)
    else:
        w_fold = weight * scale_factor.reshape(groups, *shape)

    b_fold = beta + gamma * (bias - bn_mean) * bn_istd
    return w_fold, b_fold
| @@ -84,3 +106,55 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): | |||||
| module.bn = deepcopy(bn) | module.bn = deepcopy(bn) | ||||
| new_conv.training = conv.training | new_conv.training = conv.training | ||||
| return module | return module | ||||
def fuse_conv_transpose2d_bn_relu_module(
    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
):
    r"""Fuse a ConvTranspose2d with an optional BatchNorm2d and/or ReLU.

    The fused class is looked up in ``_MAP_TO_FUSED_MODULE``. When a BN is
    given and the convolution is not in training mode, the BN running
    statistics are folded into the weight and bias; when it is in training
    mode, a deep copy of ``bn`` is attached to the fused module instead.

    Returns:
        The fused module, carrying the convolution's hyper-parameters, name
        and training flag.
    """
    training = conv_transpose2d.training

    # Key: types of the modules that are present, plus the training flag
    # when a BN takes part in the fusion.
    module_key = tuple(type(m) for m in (conv_transpose2d, bn, relu) if m)
    if bn:
        assert (
            conv_transpose2d.training == bn.training
        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
        assert (
            bn.num_features == conv_transpose2d.out_channels
        ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
        module_key += (training,)

    fused_cls = _MAP_TO_FUSED_MODULE[module_key]
    module = fused_cls(
        in_channels=conv_transpose2d.in_channels,
        out_channels=conv_transpose2d.out_channels,
        kernel_size=conv_transpose2d.kernel_size,
        stride=conv_transpose2d.stride,
        padding=conv_transpose2d.padding,
        output_padding=conv_transpose2d.output_padding,
        dilation=conv_transpose2d.dilation,
        groups=conv_transpose2d.groups,
        bias=conv_transpose2d.bias is not None,
        conv_mode=conv_transpose2d.conv_mode,
        compute_mode=conv_transpose2d.compute_mode,
        name=conv_transpose2d.name,
    )

    # The parameters live on the fused module itself unless it wraps a
    # separate conv_transpose2d sub-module (BN kept, training mode).
    if bn is None or not training:
        new_conv_transpose2d = module
    else:
        new_conv_transpose2d = module.conv_transpose2d

    weight = conv_transpose2d.weight
    bias = conv_transpose2d.bias
    if bn is not None and not training:
        # transpose=False makes fold_weight_bias broadcast the BN scale along
        # weight axis 1 (axis 2 of a grouped 5-D weight).
        # NOTE(review): presumably that is the output-channel axis of
        # ConvTranspose2d weights — confirm.
        weight, bias = fold_weight_bias(
            weight,
            bias,
            bn.weight,
            bn.bias,
            bn.running_mean,
            bn.running_var,
            bn.eps,
            transpose=False,
        )

    new_conv_transpose2d.weight = Parameter(weight)
    if bias is not None:
        new_conv_transpose2d.bias = Parameter(bias)
    if bn is not None and training:
        module.bn = deepcopy(bn)
    new_conv_transpose2d.training = training
    return module
| @@ -5,7 +5,9 @@ import numpy as np | |||||
| import pytest | import pytest | ||||
| import megengine.utils.comp_graph_tools as cgtools | import megengine.utils.comp_graph_tools as cgtools | ||||
| from megengine import jit, tensor | |||||
| from megengine import jit | |||||
| from megengine import module as M | |||||
| from megengine import tensor | |||||
| from megengine.device import get_device_count | from megengine.device import get_device_count | ||||
| from megengine.functional import expand_dims | from megengine.functional import expand_dims | ||||
| from megengine.module import ( | from megengine.module import ( | ||||
| @@ -14,6 +16,8 @@ from megengine.module import ( | |||||
| ConvBn2d, | ConvBn2d, | ||||
| ConvRelu2d, | ConvRelu2d, | ||||
| ConvTranspose2d, | ConvTranspose2d, | ||||
| ConvTransposeBn2d, | |||||
| ConvTransposeRelu2d, | |||||
| DequantStub, | DequantStub, | ||||
| Module, | Module, | ||||
| QuantStub, | QuantStub, | ||||
| @@ -34,6 +38,49 @@ def test_qat_convbn2d(): | |||||
| module = ConvBn2d( | module = ConvBn2d( | ||||
| in_channels, out_channels, kernel_size, groups=groups, bias=bias | in_channels, out_channels, kernel_size, groups=groups, bias=bias | ||||
| ) | ) | ||||
| M.init.normal_(module.bn.weight) | |||||
| M.init.normal_(module.bn.bias) | |||||
| module.train() | |||||
| qat_module = quantize_qat(module, inplace=False) | |||||
| disable_fake_quant(qat_module) | |||||
| inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||||
| normal_outputs = module(inputs) | |||||
| qat_outputs = qat_module(inputs) | |||||
| np.testing.assert_allclose( | |||||
| normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 | |||||
| ) | |||||
| np.testing.assert_allclose( | |||||
| module.bn.running_mean.numpy(), | |||||
| qat_module.bn.running_mean.numpy(), | |||||
| atol=5e-8, | |||||
| ) | |||||
| np.testing.assert_allclose( | |||||
| module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7, | |||||
| ) | |||||
| module.eval() | |||||
| normal_outputs = module(inputs) | |||||
| qat_module.eval() | |||||
| qat_outputs = qat_module(inputs) | |||||
| np.testing.assert_allclose( | |||||
| normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 | |||||
| ) | |||||
| def test_qat_convtransposebn2d(): | |||||
| in_channels = 32 | |||||
| out_channels = 64 | |||||
| kernel_size = 3 | |||||
| for groups, bias in product([1, 4], [True, False]): | |||||
| module = ConvTransposeBn2d( | |||||
| in_channels=in_channels, | |||||
| out_channels=out_channels, | |||||
| kernel_size=kernel_size, | |||||
| output_padding=0, | |||||
| groups=groups, | |||||
| bias=bias, | |||||
| ) | |||||
| M.init.normal_(module.bn.weight) | |||||
| M.init.normal_(module.bn.bias) | |||||
| module.train() | module.train() | ||||
| qat_module = quantize_qat(module, inplace=False) | qat_module = quantize_qat(module, inplace=False) | ||||
| disable_fake_quant(qat_module) | disable_fake_quant(qat_module) | ||||
| @@ -235,10 +282,14 @@ def test_qat_conv_transpose2d(): | |||||
| self.conv = ConvTranspose2d( | self.conv = ConvTranspose2d( | ||||
| in_channels, out_channels, kernel_size, bias=bias | in_channels, out_channels, kernel_size, bias=bias | ||||
| ) | ) | ||||
| self.conv_transpose2d_relu = ConvTransposeRelu2d( | |||||
| out_channels, in_channels, kernel_size, bias=bias | |||||
| ) | |||||
| def forward(self, inp): | def forward(self, inp): | ||||
| out = self.quant(inp) | out = self.quant(inp) | ||||
| out = self.conv(out) | out = self.conv(out) | ||||
| out = self.conv_transpose2d_relu(out) | |||||
| out = self.dequant(out) | out = self.dequant(out) | ||||
| return out | return out | ||||
| @@ -250,10 +301,14 @@ def test_qat_conv_transpose2d(): | |||||
| disable_fake_quant(qat_net) | disable_fake_quant(qat_net) | ||||
| normal_outputs = net(inputs) | normal_outputs = net(inputs) | ||||
| qat_outputs = qat_net(inputs) | qat_outputs = qat_net(inputs) | ||||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||||
| np.testing.assert_allclose( | |||||
| normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 | |||||
| ) | |||||
| net.eval() | net.eval() | ||||
| normal_outputs = net(inputs) | normal_outputs = net(inputs) | ||||
| qat_net.eval() | qat_net.eval() | ||||
| qat_outputs = qat_net(inputs) | qat_outputs = qat_net(inputs) | ||||
| np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||||
| np.testing.assert_allclose( | |||||
| normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 | |||||
| ) | |||||