@@ -24,8 +24,7 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         # get fold bn conv_transpose2d param
         gamma = self.bn.weight
         if gamma is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-            gamma = gamma.reshape(1, -1, 1, 1)
+            gamma = ones((1, self.bn.num_features, 1, 1), dtype="float32")
         beta = self.bn.bias
         if beta is None:
             beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
@@ -44,10 +43,10 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
         scale_factor = gamma * bn_istd
         if self.conv_transpose2d.groups == 1:
-            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1)
+            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
         else:
             w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
-                self.conv_transpose2d.groups, -1, 1, 1, 1
+                self.conv_transpose2d.groups, 1, -1, 1, 1
             )
         w_fold = self.apply_quant_weight(w_fold)
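Both hunks above come down to where the output-channel axis lives in a transposed-convolution kernel. Below is a minimal numpy sketch of the intended broadcast, assuming MegEngine stores a ConvTranspose2d kernel as (in_channels, out_channels, kh, kw) for groups == 1 and as (groups, in_channels // groups, out_channels // groups, kh, kw) otherwise, which is what the axis-1 / axis-2 indexing in these hunks implies:

import numpy as np

in_ch, out_ch, kh, kw = 4, 8, 3, 3
scale_factor = np.random.randn(out_ch).astype("float32")  # gamma * bn_istd, one value per output channel

# groups == 1: output channels sit on axis 1, so the per-channel scale must be (1, -1, 1, 1);
# the old (-1, 1, 1, 1) shape scaled along the input-channel axis instead.
weight = np.random.randn(in_ch, out_ch, kh, kw).astype("float32")
w_fold = weight * scale_factor.reshape(1, -1, 1, 1)

# grouped case: output channels sit on axis 2 of the 5-D kernel, hence (groups, 1, -1, 1, 1).
groups = 2
weight_g = np.random.randn(groups, in_ch // groups, out_ch // groups, kh, kw).astype("float32")
w_fold_g = weight_g * scale_factor.reshape(groups, 1, -1, 1, 1)

assert w_fold.shape == weight.shape and w_fold_g.shape == weight_g.shape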
@@ -32,15 +32,21 @@ _MAP_TO_FUSED_MODULE = {
 def fold_weight_bias(
     weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
 ):
-    shape = (1, -1, 1, 1)
+    shape = (-1, 1, 1, 1)
     if transpose:
-        shape = (-1, 1, 1, 1)
+        shape = (1, -1, 1, 1)
     kernel_shape = weight.shape
     if len(kernel_shape) == 5:
-        groups, num_features = kernel_shape[0], kernel_shape[1]
+        if transpose:
+            groups, num_features = kernel_shape[0], kernel_shape[2]
+        else:
+            groups, num_features = kernel_shape[0], kernel_shape[1]
     else:
-        groups, num_features = 1, kernel_shape[0]
+        if transpose:
+            groups, num_features = 1, kernel_shape[1]
+        else:
+            groups, num_features = 1, kernel_shape[0]
     out_channels = groups * num_features
     if gamma is None:
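For reference, the arithmetic fold_weight_bias implements once the channel axis is picked: in eval mode, bn(conv(x)) equals a single convolution with rescaled weights and a shifted bias. A small numpy sketch of that identity for a plain Conv2d kernel (out_channels first); this illustrates the formula only, not the library's exact code:

import numpy as np

def fold_weight_bias_ref(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
    # weight: (out_ch, in_ch, kh, kw); bias and every BN term: (out_ch,)
    bn_istd = 1.0 / np.sqrt(bn_var + eps)
    scale = gamma * bn_istd                        # per-output-channel scale
    w_fold = weight * scale.reshape(-1, 1, 1, 1)   # rescale each output-channel slice
    b_fold = beta + (bias - bn_mean) * scale       # fold the BN mean/shift into the bias
    return w_fold, b_fold

The transpose=True path uses the same scale and shift; only the reshape target changes so the scale lands on the output-channel axis of an (in, out, kh, kw) kernel, which is exactly what the shape and num_features branches above select.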
@@ -93,12 +99,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         compute_mode=conv.compute_mode,
         name=conv.name,
     )
-    new_conv = module if bn is None or not conv.training else module.conv
+    if isinstance(conv, ConvTranspose2d):
+        module.output_padding = conv.output_padding
+        new_conv = (
+            module if bn is None or not conv.training else module.conv_transpose2d
+        )
+    else:
+        new_conv = module if bn is None or not conv.training else module.conv
     weight, bias = conv.weight, conv.bias
     if not conv.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
-        )
+        if isinstance(conv, ConvTranspose2d):
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+                transpose=True,
+            )
+        else:
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+            )
     new_conv.weight = Parameter(weight)
     if bias is not None:
         new_conv.bias = Parameter(bias)
@@ -106,55 +137,3 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         module.bn = deepcopy(bn)
     new_conv.training = conv.training
     return module
-
-
-def fuse_conv_transpose2d_bn_relu_module(
-    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
-):
-    module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
-    if bn:
-        assert (
-            conv_transpose2d.training == bn.training
-        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
-        assert (
-            bn.num_features == conv_transpose2d.out_channels
-        ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
-    module_key = module_key + (conv_transpose2d.training,)
-    module = _MAP_TO_FUSED_MODULE[module_key](
-        in_channels=conv_transpose2d.in_channels,
-        out_channels=conv_transpose2d.out_channels,
-        kernel_size=conv_transpose2d.kernel_size,
-        stride=conv_transpose2d.stride,
-        padding=conv_transpose2d.padding,
-        output_padding=conv_transpose2d.output_padding,
-        dilation=conv_transpose2d.dilation,
-        groups=conv_transpose2d.groups,
-        bias=conv_transpose2d.bias is not None,
-        conv_mode=conv_transpose2d.conv_mode,
-        compute_mode=conv_transpose2d.compute_mode,
-        name=conv_transpose2d.name,
-    )
-    new_conv_transpose2d = (
-        module
-        if bn is None or not conv_transpose2d.training
-        else module.conv_transpose2d
-    )
-    weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
-    if not conv_transpose2d.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight,
-            bias,
-            bn.weight,
-            bn.bias,
-            bn.running_mean,
-            bn.running_var,
-            bn.eps,
-            transpose=False,
-        )
-    new_conv_transpose2d.weight = Parameter(weight)
-    if bias is not None:
-        new_conv_transpose2d.bias = Parameter(bias)
-    if bn is not None and conv_transpose2d.training:
-        module.bn = deepcopy(bn)
-    new_conv_transpose2d.training = conv_transpose2d.training
-    return module
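With the dedicated fuse_conv_transpose2d_bn_relu_module removed (it duplicated the generic helper and, as the removed call shows, passed transpose=False), fuse_conv_bn_relu_module is now the single entry point for both module types. A usage sketch mirroring the new tests further down; the imports are the ones added in the test-file hunks below:

from megengine.module import BatchNorm2d, ConvTranspose2d, ReLU
from megengine.utils.bn_fusion import fuse_conv_bn_relu_module

deconv = ConvTranspose2d(32, 64, 3)
bn = BatchNorm2d(64)
relu = ReLU()

# eval mode: BN is folded into the ConvTranspose2d weights via the transpose=True path
deconv.eval()
bn.eval()
relu.eval()
fused_eval = fuse_conv_bn_relu_module(deconv, bn, relu)

# train mode: a fused module holding a copy of the live BN is returned instead
deconv.train()
bn.train()
fused_train = fuse_conv_bn_relu_module(deconv, bn, None)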
@@ -34,35 +34,49 @@ def test_qat_convbn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_bn = ConvBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_bn(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvBn2d(
-            in_channels, out_channels, kernel_size, groups=groups, bias=bias
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
-        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
+            net.conv_bn.bn.running_mean.numpy(),
+            qat_net.conv_bn.bn.running_mean.numpy(),
             atol=5e-8,
         )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_bn.bn.running_var.numpy(),
+            qat_net.conv_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
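Both rewritten tests follow the same pattern: build a small float network around the fused block, convert it with quantize_qat, disable fake quantization, and require the QAT network to stay numerically equivalent to the float one in both train and eval mode. Condensed, using only the names from the hunk above:

net = TestNet(groups=1, bias=True)
net.train()
qat_net = quantize_qat(net, inplace=False)  # swap Float modules for their QAT counterparts
disable_fake_quant(qat_net)                 # keep the observers, drop the quantization error
np.testing.assert_allclose(net(inputs).numpy(), qat_net(inputs).numpy(), atol=1e-4)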
@@ -70,40 +84,44 @@ def test_qat_convtransposebn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_transpose_bn = ConvTransposeBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_transpose_bn(out)
+            out = self.dequant(out)
+            return out
+
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvTransposeBn2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            output_padding=0,
-            groups=groups,
-            bias=bias,
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
         inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
-        np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
-        )
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
-            atol=5e-8,
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
        )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_transpose_bn.bn.running_var.numpy(),
+            qat_net.conv_transpose_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
         )
@@ -3,6 +3,15 @@ import pytest
 from megengine import Parameter, Tensor
 from megengine import module as Float
+from megengine.functional import ones, zeros
+from megengine.module import (
+    BatchNorm2d,
+    Conv2d,
+    ConvBn2d,
+    ConvTranspose2d,
+    ConvTransposeBn2d,
+    ReLU,
+)
 from megengine.module import qat as QAT
 from megengine.module import quantized as Q
 from megengine.quantization import (
@@ -24,6 +33,7 @@ from megengine.quantization.quantize import (
     quantize_qat,
     reset_qconfig,
 )
+from megengine.utils.bn_fusion import fuse_conv_bn_relu_module
 
 
 class FloatNet(Float.Module):
@@ -291,3 +301,85 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_ConvBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    conv = Conv2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+
+def test_ConvTransposeBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    conv = ConvTranspose2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )