| @@ -372,7 +372,8 @@ class FakeQuantWithMinMax(Cell): | |||
| if self.is_ascend: | |||
| self.fake_quant_train = quant_fun(num_bits=self.num_bits, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range) | |||
| narrow_range=self.narrow_range, | |||
| quant_delay=self.quant_delay) | |||
| self.fake_quant_infer = self.fake_quant_train | |||
| else: | |||
| quant_fun = partial(quant_fun, | |||
| @@ -679,28 +680,40 @@ class Conv2dBnWithoutFoldQuant(Cell): | |||
| self.group = group | |||
| self.quant_delay = quant_delay | |||
| weight_shape = [out_channels, in_channels // group, *self.kernel_size] | |||
| self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') | |||
| self.bias_add = P.BiasAdd() | |||
| if check_bool(has_bias): | |||
| self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') | |||
| else: | |||
| self.bias = None | |||
| self.conv = P.Conv2D(out_channel=self.out_channels, | |||
| kernel_size=self.kernel_size, | |||
| mode=1, | |||
| pad_mode=self.pad_mode, | |||
| pad=self.padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=self.group) | |||
| # initialize convolution op and Parameter | |||
| if context.get_context('device_target') == "Ascend" and group > 1: | |||
| validator.check_integer('group', group, in_channels, Rel.EQ) | |||
| validator.check_integer('group', group, out_channels, Rel.EQ) | |||
| self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, | |||
| kernel_size=self.kernel_size, | |||
| pad_mode=pad_mode, | |||
| pad=padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation) | |||
| weight_shape = [1, in_channels, *self.kernel_size] | |||
| channel_axis = 1 | |||
| else: | |||
| self.conv = P.Conv2D(out_channel=self.out_channels, | |||
| kernel_size=self.kernel_size, | |||
| mode=1, | |||
| pad_mode=self.pad_mode, | |||
| pad=self.padding, | |||
| stride=self.stride, | |||
| dilation=self.dilation, | |||
| group=self.group) | |||
| weight_shape = [out_channels, in_channels // group, *self.kernel_size] | |||
| channel_axis = 0 | |||
| self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') | |||
| self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| per_channel=per_channel, | |||
| channel_axis=0, | |||
| channel_axis=channel_axis, | |||
| num_channels=out_channels, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| @@ -1009,6 +1022,7 @@ class ActQuant(_QuantActivation): | |||
| def get_origin(self): | |||
| return self.act | |||
| class LeakyReLUQuant(_QuantActivation): | |||
| r""" | |||
| LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP. | |||
| @@ -1078,7 +1092,6 @@ class LeakyReLUQuant(_QuantActivation): | |||
| return self.act | |||
| class HSwishQuant(_QuantActivation): | |||
| r""" | |||
| HSwishQuant activation function. Add Fake Quant OP after HSwish OP. | |||