@@ -24,8 +24,7 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         # get fold bn conv_transpose2d param
         gamma = self.bn.weight
         if gamma is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-            gamma = gamma.reshape(1, -1, 1, 1)
+            gamma = ones((1, self.bn.num_features, 1, 1), dtype="float32")
         beta = self.bn.bias
         if beta is None:
             beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
@@ -44,10 +43,10 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
         scale_factor = gamma * bn_istd
         if self.conv_transpose2d.groups == 1:
-            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1)
+            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
         else:
             w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
-                self.conv_transpose2d.groups, -1, 1, 1, 1
+                self.conv_transpose2d.groups, 1, -1, 1, 1
             )
         w_fold = self.apply_quant_weight(w_fold)
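The reshape direction above follows from MegEngine's kernel layout for transposed convolution: a non-grouped ConvTranspose2d stores its weight as (in_channels, out_channels, kh, kw), so the per-output-channel BN scale has to broadcast along axis 1, and a grouped kernel as (groups, in_channels // groups, out_channels // groups, kh, kw), so it broadcasts along axis 2. A minimal NumPy sketch of the folding step, with illustrative shapes and names that are not part of the patch:

```python
import numpy as np

# Illustrative shapes: a non-grouped ConvTranspose2d kernel is (C_in, C_out, kh, kw).
c_in, c_out, kh, kw = 4, 8, 3, 3
w = np.random.randn(c_in, c_out, kh, kw).astype("float32")

gamma = np.random.randn(c_out).astype("float32")      # BN weight
var = np.random.rand(c_out).astype("float32") + 0.5   # BN running_var
eps = 1e-5

scale_factor = gamma / np.sqrt(var + eps)              # gamma * bn_istd

# Fold the BN scale into the kernel: output channels live on axis 1,
# so the scale is reshaped to (1, -1, 1, 1) before the multiply.
w_fold = w * scale_factor.reshape(1, -1, 1, 1)
assert w_fold.shape == w.shape

# Reshaping to (-1, 1, 1, 1) would instead apply the scale along axis 0
# (input channels), which is only correct for an ordinary Conv2d layout.
```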
@@ -32,15 +32,21 @@ _MAP_TO_FUSED_MODULE = {
 def fold_weight_bias(
     weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
 ):
-    shape = (1, -1, 1, 1)
+    shape = (-1, 1, 1, 1)
     if transpose:
-        shape = (-1, 1, 1, 1)
+        shape = (1, -1, 1, 1)
     kernel_shape = weight.shape
     if len(kernel_shape) == 5:
-        groups, num_features = kernel_shape[0], kernel_shape[1]
+        if transpose:
+            groups, num_features = kernel_shape[0], kernel_shape[2]
+        else:
+            groups, num_features = kernel_shape[0], kernel_shape[1]
     else:
-        groups, num_features = 1, kernel_shape[0]
+        if transpose:
+            groups, num_features = 1, kernel_shape[1]
+        else:
+            groups, num_features = 1, kernel_shape[0]
     out_channels = groups * num_features
     if gamma is None:
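For reference, the two kernel_shape branches encode MegEngine's weight layouts: a grouped Conv2d kernel is (groups, out_channels // groups, in_channels // groups, kh, kw), a grouped ConvTranspose2d kernel is (groups, in_channels // groups, out_channels // groups, kh, kw), and the dense 4-d layouts are (C_out, C_in, kh, kw) versus (C_in, C_out, kh, kw). A small sketch restating that index logic (hypothetical helper name, not code from the patch):

```python
def bn_out_channels(kernel_shape, transpose=False):
    # Recover the number of BN features (total output channels) from a
    # Conv2d / ConvTranspose2d kernel shape, mirroring the branch above.
    if len(kernel_shape) == 5:
        # grouped: (g, C_out//g, C_in//g, kh, kw) vs (g, C_in//g, C_out//g, kh, kw)
        groups = kernel_shape[0]
        num_features = kernel_shape[2] if transpose else kernel_shape[1]
    else:
        # dense: (C_out, C_in, kh, kw) vs (C_in, C_out, kh, kw)
        groups = 1
        num_features = kernel_shape[1] if transpose else kernel_shape[0]
    return groups * num_features


# grouped ConvTranspose2d with groups=4, C_in=32, C_out=64 -> 64 BN features
assert bn_out_channels((4, 8, 16, 3, 3), transpose=True) == 64
# dense Conv2d with C_out=64, C_in=32 -> 64 BN features
assert bn_out_channels((64, 32, 3, 3)) == 64
```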
@@ -93,12 +99,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         compute_mode=conv.compute_mode,
         name=conv.name,
     )
-    new_conv = module if bn is None or not conv.training else module.conv
+    if isinstance(conv, ConvTranspose2d):
+        module.output_padding = conv.output_padding
+        new_conv = (
+            module if bn is None or not conv.training else module.conv_transpose2d
+        )
+    else:
+        new_conv = module if bn is None or not conv.training else module.conv
     weight, bias = conv.weight, conv.bias
     if not conv.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
-        )
+        if isinstance(conv, ConvTranspose2d):
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+                transpose=True,
+            )
+        else:
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+            )
     new_conv.weight = Parameter(weight)
     if bias is not None:
         new_conv.bias = Parameter(bias)
@@ -106,55 +137,3 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         module.bn = deepcopy(bn)
     new_conv.training = conv.training
     return module
-
-
-def fuse_conv_transpose2d_bn_relu_module(
-    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
-):
-    module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
-    if bn:
-        assert (
-            conv_transpose2d.training == bn.training
-        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
-        assert (
-            bn.num_features == conv_transpose2d.out_channels
-        ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
-    module_key = module_key + (conv_transpose2d.training,)
-    module = _MAP_TO_FUSED_MODULE[module_key](
-        in_channels=conv_transpose2d.in_channels,
-        out_channels=conv_transpose2d.out_channels,
-        kernel_size=conv_transpose2d.kernel_size,
-        stride=conv_transpose2d.stride,
-        padding=conv_transpose2d.padding,
-        output_padding=conv_transpose2d.output_padding,
-        dilation=conv_transpose2d.dilation,
-        groups=conv_transpose2d.groups,
-        bias=conv_transpose2d.bias is not None,
-        conv_mode=conv_transpose2d.conv_mode,
-        compute_mode=conv_transpose2d.compute_mode,
-        name=conv_transpose2d.name,
-    )
-    new_conv_transpose2d = (
-        module
-        if bn is None or not conv_transpose2d.training
-        else module.conv_transpose2d
-    )
-    weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
-    if not conv_transpose2d.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight,
-            bias,
-            bn.weight,
-            bn.bias,
-            bn.running_mean,
-            bn.running_var,
-            bn.eps,
-            transpose=False,
-        )
-    new_conv_transpose2d.weight = Parameter(weight)
-    if bias is not None:
-        new_conv_transpose2d.bias = Parameter(bias)
-    if bn is not None and conv_transpose2d.training:
-        module.bn = deepcopy(bn)
-    new_conv_transpose2d.training = conv_transpose2d.training
-    return module
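With the ConvTranspose2d handling folded into fuse_conv_bn_relu_module, the dedicated fuse_conv_transpose2d_bn_relu_module helper above becomes redundant and is removed. A usage sketch of the single entry point, mirroring the new tests added below (assuming the fused-module mapping covers the ConvTransposeBn variants, as those tests exercise):

```python
from megengine.module import BatchNorm2d, Conv2d, ConvTranspose2d, ReLU
from megengine.utils.bn_fusion import fuse_conv_bn_relu_module

# One entry point now covers both layer types.
fused_conv = fuse_conv_bn_relu_module(Conv2d(32, 64, 3), BatchNorm2d(64), ReLU())
fused_deconv = fuse_conv_bn_relu_module(
    ConvTranspose2d(32, 64, 3), BatchNorm2d(64), ReLU()
)
```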
@@ -34,35 +34,49 @@ def test_qat_convbn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
 
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_bn = ConvBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_bn(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvBn2d(
-            in_channels, out_channels, kernel_size, groups=groups, bias=bias
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
-        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
+            net.conv_bn.bn.running_mean.numpy(),
+            qat_net.conv_bn.bn.running_mean.numpy(),
             atol=5e-8,
         )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_bn.bn.running_var.numpy(),
+            qat_net.conv_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
@@ -70,40 +84,44 @@ def test_qat_convtransposebn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
 
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_transpose_bn = ConvTransposeBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_transpose_bn(out)
+            out = self.dequant(out)
+            return out
+
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvTransposeBn2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            output_padding=0,
-            groups=groups,
-            bias=bias,
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
         inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
-        np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
-        )
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
-            atol=5e-8,
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
         )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_transpose_bn.bn.running_var.numpy(),
+            qat_net.conv_transpose_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
         )
@@ -3,6 +3,15 @@ import pytest
 from megengine import Parameter, Tensor
 from megengine import module as Float
 from megengine.functional import ones, zeros
+from megengine.module import (
+    BatchNorm2d,
+    Conv2d,
+    ConvBn2d,
+    ConvTranspose2d,
+    ConvTransposeBn2d,
+    ReLU,
+)
 from megengine.module import qat as QAT
 from megengine.module import quantized as Q
 from megengine.quantization import (
@@ -24,6 +33,7 @@ from megengine.quantization.quantize import (
     quantize_qat,
     reset_qconfig,
 )
+from megengine.utils.bn_fusion import fuse_conv_bn_relu_module
 
 
 class FloatNet(Float.Module):
@@ -291,3 +301,85 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_ConvBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    conv = Conv2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+
+def test_ConvTransposeBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    conv = ConvTranspose2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )