|
|
|
@@ -22,7 +22,7 @@ from mindspore.common.parameter import Parameter |
|
|
|
from mindspore.common.initializer import initializer |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
from mindspore._checkparam import check_int_positive, check_bool, twice |
|
|
|
from mindspore._checkparam import Validator as validator |
|
|
|
from mindspore._checkparam import Validator as validator, Rel |
|
|
|
from mindspore.nn.cell import Cell |
|
|
|
from mindspore.nn.layer.activation import get_activation |
|
|
|
import mindspore.context as context |
|
|
|
@@ -207,7 +207,7 @@ class FakeQuantWithMinMaxD(Cell): |
|
|
|
|
|
|
|
class FakeQuantWithMinMax(Cell): |
|
|
|
r""" |
|
|
|
Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max. |
|
|
|
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. |
|
|
|
|
|
|
|
Args: |
|
|
|
min_init (int, list): The dimension of channel or 1(layer). Default: -6. |
|
|
|
@@ -243,8 +243,7 @@ class FakeQuantWithMinMax(Cell): |
|
|
|
out_channels=1, |
|
|
|
quant_delay=0, |
|
|
|
symmetric=False, |
|
|
|
narrow_range=False, |
|
|
|
training=True): |
|
|
|
narrow_range=False): |
|
|
|
"""init FakeQuantWithMinMax layer""" |
|
|
|
super(FakeQuantWithMinMax, self).__init__() |
|
|
|
|
|
|
|
@@ -258,7 +257,6 @@ class FakeQuantWithMinMax(Cell): |
|
|
|
self.quant_delay = quant_delay |
|
|
|
self.symmetric = symmetric |
|
|
|
self.narrow_range = narrow_range |
|
|
|
self.training = training |
|
|
|
|
|
|
|
if per_channel: |
|
|
|
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) |
|
|
|
@@ -422,11 +420,13 @@ class Conv2dBatchNormQuant(Cell): |
|
|
|
self.per_channel = per_channel |
|
|
|
self.symmetric = symmetric |
|
|
|
self.narrow_range = narrow_range |
|
|
|
self.channel_axis = int(group > 1) |
|
|
|
self.is_gpu = context.get_context('device_target') == "GPU" |
|
|
|
|
|
|
|
# initialize convolution op and Parameter |
|
|
|
if context.get_context('device_target') == "Ascend" and group > 1: |
|
|
|
validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant') |
|
|
|
validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant') |
|
|
|
validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant') |
|
|
|
validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant') |
|
|
|
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, |
|
|
|
kernel_size=self.kernel_size, |
|
|
|
pad_mode=pad_mode, |
|
|
|
@@ -472,7 +472,7 @@ class Conv2dBatchNormQuant(Cell): |
|
|
|
symmetric=symmetric, |
|
|
|
narrow_range=narrow_range) |
|
|
|
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) |
|
|
|
self.correct_mul = P.CorrectionMul() |
|
|
|
self.correct_mul = P.CorrectionMul(self.channel_axis) |
|
|
|
if context.get_context('device_target') == "Ascend": |
|
|
|
self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn) |
|
|
|
self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0) |
|
|
|
|