|
|
@@ -372,7 +372,8 @@ class FakeQuantWithMinMax(Cell): |
|
|
if self.is_ascend: |
|
|
if self.is_ascend: |
|
|
self.fake_quant_train = quant_fun(num_bits=self.num_bits, |
|
|
self.fake_quant_train = quant_fun(num_bits=self.num_bits, |
|
|
symmetric=self.symmetric, |
|
|
symmetric=self.symmetric, |
|
|
narrow_range=self.narrow_range) |
|
|
|
|
|
|
|
|
narrow_range=self.narrow_range, |
|
|
|
|
|
quant_delay=self.quant_delay) |
|
|
self.fake_quant_infer = self.fake_quant_train |
|
|
self.fake_quant_infer = self.fake_quant_train |
|
|
else: |
|
|
else: |
|
|
quant_fun = partial(quant_fun, |
|
|
quant_fun = partial(quant_fun, |
|
|
@@ -679,28 +680,40 @@ class Conv2dBnWithoutFoldQuant(Cell): |
|
|
self.group = group |
|
|
self.group = group |
|
|
self.quant_delay = quant_delay |
|
|
self.quant_delay = quant_delay |
|
|
|
|
|
|
|
|
weight_shape = [out_channels, in_channels // group, *self.kernel_size] |
|
|
|
|
|
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') |
|
|
|
|
|
|
|
|
|
|
|
self.bias_add = P.BiasAdd() |
|
|
self.bias_add = P.BiasAdd() |
|
|
if check_bool(has_bias): |
|
|
if check_bool(has_bias): |
|
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') |
|
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') |
|
|
else: |
|
|
else: |
|
|
self.bias = None |
|
|
self.bias = None |
|
|
|
|
|
|
|
|
self.conv = P.Conv2D(out_channel=self.out_channels, |
|
|
|
|
|
kernel_size=self.kernel_size, |
|
|
|
|
|
mode=1, |
|
|
|
|
|
pad_mode=self.pad_mode, |
|
|
|
|
|
pad=self.padding, |
|
|
|
|
|
stride=self.stride, |
|
|
|
|
|
dilation=self.dilation, |
|
|
|
|
|
group=self.group) |
|
|
|
|
|
|
|
|
# initialize convolution op and Parameter |
|
|
|
|
|
if context.get_context('device_target') == "Ascend" and group > 1: |
|
|
|
|
|
validator.check_integer('group', group, in_channels, Rel.EQ) |
|
|
|
|
|
validator.check_integer('group', group, out_channels, Rel.EQ) |
|
|
|
|
|
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, |
|
|
|
|
|
kernel_size=self.kernel_size, |
|
|
|
|
|
pad_mode=pad_mode, |
|
|
|
|
|
pad=padding, |
|
|
|
|
|
stride=self.stride, |
|
|
|
|
|
dilation=self.dilation) |
|
|
|
|
|
weight_shape = [1, in_channels, *self.kernel_size] |
|
|
|
|
|
channel_axis = 1 |
|
|
|
|
|
else: |
|
|
|
|
|
self.conv = P.Conv2D(out_channel=self.out_channels, |
|
|
|
|
|
kernel_size=self.kernel_size, |
|
|
|
|
|
mode=1, |
|
|
|
|
|
pad_mode=self.pad_mode, |
|
|
|
|
|
pad=self.padding, |
|
|
|
|
|
stride=self.stride, |
|
|
|
|
|
dilation=self.dilation, |
|
|
|
|
|
group=self.group) |
|
|
|
|
|
weight_shape = [out_channels, in_channels // group, *self.kernel_size] |
|
|
|
|
|
channel_axis = 0 |
|
|
|
|
|
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') |
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, |
|
|
max_init=6, |
|
|
max_init=6, |
|
|
ema=False, |
|
|
ema=False, |
|
|
per_channel=per_channel, |
|
|
per_channel=per_channel, |
|
|
channel_axis=0, |
|
|
|
|
|
|
|
|
channel_axis=channel_axis, |
|
|
num_channels=out_channels, |
|
|
num_channels=out_channels, |
|
|
num_bits=num_bits, |
|
|
num_bits=num_bits, |
|
|
symmetric=symmetric, |
|
|
symmetric=symmetric, |
|
|
@@ -1009,6 +1022,7 @@ class ActQuant(_QuantActivation): |
|
|
def get_origin(self): |
|
|
def get_origin(self): |
|
|
return self.act |
|
|
return self.act |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LeakyReLUQuant(_QuantActivation): |
|
|
class LeakyReLUQuant(_QuantActivation): |
|
|
r""" |
|
|
r""" |
|
|
LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP. |
|
|
LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP. |
|
|
@@ -1078,7 +1092,6 @@ class LeakyReLUQuant(_QuantActivation): |
|
|
return self.act |
|
|
return self.act |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HSwishQuant(_QuantActivation): |
|
|
class HSwishQuant(_QuantActivation): |
|
|
r""" |
|
|
r""" |
|
|
HSwishQuant activation function. Add Fake Quant OP after HSwish OP. |
|
|
HSwishQuant activation function. Add Fake Quant OP after HSwish OP. |
|
|
|