@@ -419,12 +419,17 @@ class Conv2dBnFoldQuantOneConv(Cell):
        self.fake = fake
        self.quant_config = quant_config
        self.quant_dtype = quant_dtype
        data_format = 'NCHW'
        self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        self.is_gpu = context.get_context('device_target') == "GPU"
        self.is_Ascend = context.get_context('device_target') == "Ascend"
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False
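        # Graph mode on the GE backend or on Ascend takes the default BN-folding training path in construct().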
        self.enable_default_train = self.is_graph_mode and \
                                    (self.is_ge_backend or self.is_Ascend)

        # initialize convolution op and Parameter
        if context.get_context('device_target') == "Ascend" and group > 1:
@@ -448,6 +453,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
                                 group=group)
            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
            channel_axis = 0
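        # channel_axis records which weight dimension carries the folded BN scale; construct() uses it to broadcast the scale correctly.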
        self.channel_axis = channel_axis
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        if Validator.check_bool(has_bias):
@@ -490,7 +496,6 @@ class Conv2dBnFoldQuantOneConv(Cell):
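        # Ops used in construct() to update the moving statistics, sharded with the data-parallel strategy.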
        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
        self.one = Tensor(1, mstype.int32)
        self.reshape = P.Reshape()

    def extend_repr(self):
@@ -507,18 +512,22 @@ class Conv2dBnFoldQuantOneConv(Cell):
    def construct(self, x):
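        # Fold the BatchNorm scale into the convolution weight: scale_factor = gamma / sqrt(moving_variance + eps).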
        running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps))
        scale_factor = self.gamma / running_std
        if self.channel_axis:
            scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
        else:
            scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1))
        weight = self.weight * scale_factor
        if self.fake:
            weight = self.fake_quant_weight(weight)
        conv = self.conv(x, weight)
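        # Undo the folded scale on the conv output before batch normalization; the default (Ascend/GE) path multiplies by the reciprocal instead of dividing.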
        if self.enable_default_train:
            scale_factor = P.Reciprocal()(scale_factor)
            conv_orig = conv * scale_factor
        else:
            conv_orig = conv / scale_factor
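        # Training: normalize with batch statistics; on the default path the moving statistics are updated explicitly below.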
        if self.training:
            if self.enable_default_train:
                out, batch_mean, batch_var, _, _ = self.bn_train(conv_orig,
                                                                 self.gamma,
                                                                 self.beta,
@@ -531,20 +540,19 @@ class Conv2dBnFoldQuantOneConv(Cell):
                temp_variance = self.mul_var(mean_sub2, self.momentum)
                out = F.depend(out, self.assign_sub_mean(self.moving_mean, temp_mean))
                out = F.depend(out, self.assign_sub_var(self.moving_variance, temp_variance))
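                # F.depend ties the moving mean/variance updates to the output so the AssignSub ops are kept in the graph.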
                return out
            return self.bn_train(conv_orig,
                                 self.gamma,
                                 self.beta,
                                 self.moving_mean,
                                 self.moving_variance)[0]
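        # Inference: normalize with the stored moving statistics.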
        return self.bn_infer(conv_orig,
                             self.gamma,
                             self.beta,
                             self.moving_mean,
                             self.moving_variance)[0]


class Conv2dBnFoldQuant(Cell):