diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 0bb69b4476..79661f72e1 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -419,8 +419,11 @@ class Conv2dBnFoldQuant(Cell):
         padding (int): Implicit paddings on both sides of the input. Default: 0.
         eps (float): Parameters for BatchNormal. Default: 1e-5.
         momentum (float): Parameters for BatchNormal op. Default: 0.997.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
             convolution kernel. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
+            bias vector. Default: 'zeros'.
         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
             beta vector. Default: 'zeros'.
         gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
@@ -460,7 +463,9 @@ class Conv2dBnFoldQuant(Cell):
                  group=1,
                  eps=1e-5,
                  momentum=0.997,
+                 has_bias=False,
                  weight_init='normal',
+                 bias_init='zeros',
                  beta_init='zeros',
                  gamma_init='ones',
                  mean_init='zeros',
@@ -484,6 +489,7 @@
         self.group = group
         self.eps = eps
         self.momentum = momentum
+        self.has_bias = has_bias
         self.quant_delay = quant_delay
         self.freeze_bn = freeze_bn
         self.fake = fake
@@ -516,6 +522,11 @@
             weight_shape = [out_channels, in_channels // group, *self.kernel_size]
             channel_axis = 0
         self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
+        self.bias_add = P.BiasAdd()
+        if check_bool(has_bias):
+            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
+        else:
+            self.bias = None
 
         # initialize BatchNorm Parameter
         self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
@@ -562,6 +573,8 @@
 
     def construct(self, x):
         out_conv = self.conv(x, self.weight)
+        if self.has_bias:
+            out_conv = self.bias_add(out_conv, self.bias)
         # BN fold1
         batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
                                                                                self.moving_mean,
@@ -572,6 +585,8 @@
         if self.fake:
             weight = self.fake_quant_weight(weight)
         out = self.conv(x, weight)
+        if self.has_bias:
+            out = self.bias_add(out, self.bias)
         # BN fold2
         if self.is_gpu:
             if self.training:
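
Usage sketch (not part of the patch): the snippet below illustrates how the new has_bias and bias_init arguments would be used once this change lands. The direct import path, the pad_mode argument, and the output-shape comment are assumptions based on this file and typical MindSpore conventions; they are not confirmed by the diff itself.

    # Minimal sketch, assuming Conv2dBnFoldQuant is importable from this module
    # and accepts the usual Conv2d-style shape arguments. API details may differ
    # across MindSpore versions.
    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import Conv2dBnFoldQuant

    # BN-fold quantization-aware conv layer with an explicit bias vector.
    conv = Conv2dBnFoldQuant(in_channels=3,
                             out_channels=16,
                             kernel_size=3,
                             pad_mode='same',     # assumed argument, mirrors nn.Conv2d
                             has_bias=True,       # new argument from this patch
                             bias_init='zeros')   # new argument from this patch

    x = Tensor(np.ones([1, 3, 32, 32], dtype=np.float32))
    out = conv(x)  # bias is added to the conv output before BN folding
    print(out.shape)  # expected (1, 16, 32, 32) with 'same' padding and stride 1

Design note, grounded in the diff: setting self.bias = None when has_bias is False mirrors nn.Conv2d, and construct guards both convolution passes with self.has_bias, so the bias contributes to the BN-fold statistics (fold1) as well as to the fake-quantized forward pass.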