From b070ab1592a6dd03bcd68bd98bb9fc6fb64e7559 Mon Sep 17 00:00:00 2001 From: xiaoyisd Date: Thu, 12 Nov 2020 21:04:25 +0800 Subject: [PATCH] add Conv2dBnFoldQuantOneConv --- mindspore/nn/layer/combined.py | 2 +- mindspore/nn/layer/quant.py | 217 ++++++++++++++++++++++++++++++++- 2 files changed, 217 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/combined.py b/mindspore/nn/layer/combined.py index 208cc904b4..a5582f6ab2 100644 --- a/mindspore/nn/layer/combined.py +++ b/mindspore/nn/layer/combined.py @@ -97,7 +97,7 @@ class Conv2dBnAct(Cell): weight_init='normal', bias_init='zeros', has_bn=False, - momentum=0.9, + momentum=0.997, eps=1e-5, activation=None, alpha=0.2, diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index cfb779c381..0449e81429 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -34,6 +34,7 @@ from ...ops.operations import _quant_ops as Q __all__ = [ 'FakeQuantWithMinMaxObserver', + 'Conv2dBnFoldQuantOneConv', 'Conv2dBnFoldQuant', 'Conv2dBnWithoutFoldQuant', 'Conv2dQuant', @@ -330,6 +331,220 @@ QuantConfig = namedtuple("QuantConfig", ['weight', 'activation']) quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver) +class Conv2dBnFoldQuantOneConv(Cell): + r""" + 2D convolution with BatchNormal op folded construct. + + This part is a more detailed overview of Conv2d op. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. + stride (int): Specifies stride for all spatial dimensions with the same value. + pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding (int): Implicit paddings on both sides of the input. Default: 0. + eps (float): Parameters for BatchNormal. Default: 1e-5. + momentum (float): Parameters for BatchNormal op. Default: 0.997. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + convolution kernel. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + bias vector. Default: 'zeros'. + beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + beta vector. Default: 'zeros'. + gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + gamma vector. Default: 'ones'. + mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + mean vector. Default: 'zeros'. + var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + variance vector. Default: 'ones'. + fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True. + quant_config (QuantConfig): Configs the oberser types and quant configs of weight and activation. Default: + both set to default FakeQuantWithMinMaxObserver. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> qconfig = compression.quant.create_quant_config() + >>> conv2d_bnfold = nn.Conv2dBnFoldQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid", + >>> quant_config=qconfig) + >>> input = Tensor(np.random.randint(-2, 2, (2, 1, 3, 3)), mindspore.float32) + >>> result = conv2d_bnfold(input) + >>> result.shape + (2, 6, 2, 2) + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + eps=1e-5, + momentum=0.997, + has_bias=False, + weight_init='normal', + bias_init='zeros', + beta_init='zeros', + gamma_init='ones', + mean_init='zeros', + var_init='ones', + fake=True, + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): + """Initialize Conv2dBnFoldQuant layer""" + super(Conv2dBnFoldQuantOneConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = twice(kernel_size) + self.stride = twice(stride) + self.pad_mode = pad_mode + self.padding = padding + self.dilation = twice(dilation) + self.group = group + self.eps = eps + self.momentum = momentum + self.has_bias = has_bias + self.fake = fake + self.quant_config = quant_config + self.quant_dtype = quant_dtype + self.is_gpu = context.get_context('device_target') == "GPU" + self.is_Ascend = context.get_context('device_target') == "Ascend" + if context.get_context("enable_ge"): + self.is_ge_backend = True + else: + self.is_ge_backend = False + + # initialize convolution op and Parameter + if context.get_context('device_target') == "Ascend" and group > 1: + Validator.check_equal_int(group, in_channels, 'group') + Validator.check_equal_int(group, out_channels, 'group') + self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, + kernel_size=self.kernel_size, + pad_mode=pad_mode, + pad=padding, + stride=self.stride, + dilation=self.dilation) + weight_shape = [1, in_channels, *self.kernel_size] + channel_axis = 1 + else: + self.conv = P.Conv2D(out_channel=out_channels, + kernel_size=self.kernel_size, + pad_mode=pad_mode, + pad=padding, + stride=self.stride, + dilation=self.dilation, + group=group) + weight_shape = [out_channels, in_channels // group, *self.kernel_size] + channel_axis = 0 + self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') + self.bias_add = P.BiasAdd() + if Validator.check_bool(has_bias): + self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') + else: + self.bias = None + + # initialize BatchNorm Parameter + self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma') + self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta') + self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False) + self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance', + requires_grad=False) + + # initialize fake ops + self.fake_quant_weight = quant_config.weight(min_init=-6, + max_init=6, + ema=False, + channel_axis=channel_axis, + num_channels=out_channels, + quant_dtype=quant_dtype) + if self.is_graph_mode and (self.is_ge_backend or self.is_ascend): + self.bn_train = P.BatchNorm(is_training=True, + epsilon=self.eps) + elif self.is_gpu: + self.bn_train = P.FusedBatchNormEx(mode=1, + epsilon=self.eps, + momentum=self.momentum, + data_format=self.format) + else: + self.bn_train = P.FusedBatchNorm(mode=1, + epsilon=self.eps, + momentum=self.momentum) + self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) + data_parallel_strategy = ((1,), (1,)) + data_parallel_strategy_one = ((1,), ()) + self.sub_mean = P.Sub().shard(data_parallel_strategy) + self.sub_var = P.Sub().shard(data_parallel_strategy) + self.mul_mean = P.Mul().shard(data_parallel_strategy_one) + self.mul_var = P.Mul().shard(data_parallel_strategy_one) + self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy) + self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy) + self.one = Tensor(1, mstype.int32) + self.reshape = P.Reshape() + + def extend_repr(self): + s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ + 'pad_mode={}, padding={}, dilation={}, group={}, ' \ + 'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels, + self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, + self.group, + self.fake, self.freeze_bn, self.momentum, + self.fake_quant_weight.quant_delay) + return s + + def construct(self, x): + running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps)) + scale_factor = self.gamma / running_std + weight = self.weight * scale_factor + if self.channel_axis: + scale_factor = self.reshape(scale_factor, (1, -1, 1, 1)) + else: + scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1)) + if self.fake: + weight = self.fake_quant_weight(weight) + conv = self.conv(x, weight) + scale_factor = self.reshape(scale_factor, (1, -1, 1, 1)) + conv_orig = conv / scale_factor + if self.training: + if not self.is_gpu: + out, batch_mean, batch_var, _, _ = self.bn_train(conv_orig, + self.gamma, + self.beta, + None, + None) + + mean_sub = self.sub_mean(self.moving_mean, batch_mean) + temp_mean = self.mul_mean(mean_sub, self.momentum) + mean_sub2 = self.sub_var(self.moving_variance, batch_var) + temp_variance = self.mul_var(mean_sub2, self.momentum) + out = F.depend(out, self.assign_sub_mean(self.moving_mean, temp_mean)) + out = F.depend(out, self.assign_sub_var(self.moving_variance, temp_variance)) + else: + out = self.bn_train(conv_orig, + self.gamma, + self.beta, + self.moving_mean, + self.moving_variance)[0] + else: + out = self.bn_infer(conv_orig, + self.gamma, + self.beta, + self.moving_mean, + self.moving_variance)[0] + + return out + + class Conv2dBnFoldQuant(Cell): r""" 2D convolution with BatchNormal op folded construct. @@ -627,7 +842,7 @@ class Conv2dBnWithoutFoldQuant(Cell): channel_axis=channel_axis, num_channels=out_channels, quant_dtype=quant_dtype) - self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum) + self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=1-momentum) def construct(self, x): weight = self.fake_quant_weight(self.weight)