diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 15cf5b58c0..570c760d15 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -27,10 +27,8 @@ from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 import mindspore.context as context
 
-
 __all__ = [
     'FakeQuantWithMinMax',
-    'DepthwiseConv2dBatchNormQuant',
    'Conv2dBatchNormQuant',
     'Conv2dQuant',
     'DenseQuant',
@@ -113,7 +111,7 @@ class FakeQuantWithMinMaxD(Cell):
         max_init (int, list): Initial max value, given per channel or as one value per layer. Default: 6.
         num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
         ema (bool): Whether to update min and max with an Exponential Moving Average. Default: False.
-        ema_decay (float): Decay rate of the Exponential Moving Average. Default: 0.9999.
+        ema_decay (float): Decay rate of the Exponential Moving Average. Default: 0.999.
         per_channel (bool): Whether to quantize per channel rather than per layer. Default: False.
         out_channels (int): Declares the min and max channel size. Default: 1.
         quant_delay (int): Number of global steps after which quantization starts. Default: 0.
@@ -131,6 +129,7 @@ class FakeQuantWithMinMaxD(Cell):
         >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
         >>> result = fake_quant(input_x)
     """
+
     def __init__(self,
                  min_init=-6,
                  max_init=6,
@@ -215,7 +214,7 @@ class FakeQuantWithMinMax(Cell):
         max_init (int, list): Initial max value, given per channel or as one value per layer. Default: 6.
         num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
         ema (bool): Whether to update min and max with an Exponential Moving Average. Default: False.
-        ema_decay (float): Decay rate of the Exponential Moving Average. Default: 0.9999.
+        ema_decay (float): Decay rate of the Exponential Moving Average. Default: 0.999.
         per_channel (bool): Whether to quantize per channel rather than per layer. Default: False.
         out_channels (int): Declares the min and max channel size. Default: 1.
         quant_delay (int): Number of global steps after which quantization starts. Default: 0.
@@ -335,191 +334,6 @@ class FakeQuantWithMinMax(Cell):
         return out
 
 
-class DepthwiseConv2dBatchNormQuant(Cell):
-    r"""
-    2D depthwise convolution with a folded BatchNorm layer.
-
-    For a more detailed overview, see the Conv2d op.
-
-    Args:
-        in_channels (int): The number of input channels :math:`C_{in}`.
-        out_channels (int): The number of output channels :math:`C_{out}`.
-        kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
-        stride (int): Stride applied to both spatial dimensions. Default: 1.
-        pad_mode (str): Specifies the padding mode. The optional values are "same", "valid", "pad". Default: "same".
-        padding (int): Implicit paddings on both sides of the input. Default: 0.
-        eps (float): Epsilon for the BatchNorm op. Default: 1e-5.
-        momentum (float): Momentum for the BatchNorm op. Default: 0.997.
-        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            convolution kernel. Default: None.
-        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            beta vector. Default: None.
-        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            gamma vector. Default: None.
-        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            mean vector. Default: None.
-        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
-            variance vector. Default: None.
-        quant_delay (int): Number of global steps after which quantization starts. Default: 0.
-        freeze_bn (int): Global step at which the BatchNorm parameters are frozen. Default: 100000.
-        fake (bool): Whether the cell adds a FakeQuantWithMinMax op. Default: True.
-        num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
-        per_channel (bool): Whether to quantize per channel rather than per layer. Default: False.
-        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
-        narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
-
-    Inputs:
-        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
-
-    Outputs:
-        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
-
-    Examples:
-        >>> quant = nn.DepthwiseConv2dBatchNormQuant(1, 6,
-        ...                                          kernel_size=(2, 2),
-        ...                                          stride=(1, 1),
-        ...                                          pad_mode="valid",
-        ...                                          dilation=(1, 1))
-        >>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
-        >>> result = quant(input_x)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 pad_mode='same',
-                 padding=0,
-                 dilation=1,
-                 group=1,
-                 eps=1e-5,
-                 momentum=0.997,
-                 weight_init=None,
-                 beta_init=None,
-                 gamma_init=None,
-                 mean_init=None,
-                 var_init=None,
-                 quant_delay=0,
-                 freeze_bn=100000,
-                 fake=True,
-                 num_bits=8,
-                 per_channel=False,
-                 symmetric=False,
-                 narrow_range=False):
-        """init DepthwiseConv2dBatchNormQuant layer"""
-        super(DepthwiseConv2dBatchNormQuant, self).__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.pad_mode = pad_mode
-        self.padding = padding
-        self.dilation = twice(dilation)
-        self.stride = twice(stride)
-        self.group = group
-        self.fake = fake
-        self.freeze_bn = freeze_bn
-        self.momentum = momentum
-        self.quant_delay = quant_delay
-        if isinstance(kernel_size, int):
-            self.kernel_size = (kernel_size, kernel_size)
-        else:
-            self.kernel_size = kernel_size
-        if group > 1:
-            validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
-            validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
-        self.is_depthwise = group > 1
-
-        channel_multiplier = out_channels // in_channels
-        self.conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
-                                            kernel_size=kernel_size,
-                                            stride=stride,
-                                            pad_mode=pad_mode,
-                                            pad=padding)
-
-        if weight_init is None:
-            weight_init = initializer('normal', [channel_multiplier, in_channels, *kernel_size])
-        self.weight = Parameter(weight_init, name='weight')
-        if gamma_init is None:
-            gamma_init = initializer('ones', [out_channels])
-        self.gamma = Parameter(gamma_init, name='gamma')
-        if beta_init is None:
-            beta_init = initializer('zeros', [out_channels])
-        self.beta = Parameter(beta_init, name='beta')
-        if mean_init is None:
-            mean_init = initializer('zeros', [out_channels])
-        self.moving_mean = Parameter(
-            mean_init, name='moving_mean', requires_grad=False)
-        if var_init is None:
-            var_init = initializer('ones', [out_channels])
-        self.moving_variance = Parameter(
-            var_init, name='moving_variance', requires_grad=False)
-
-        self.step = Parameter(initializer(
-            'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
-
-        self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
-                                                     max_init=6,
-                                                     ema=False,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
-                                                     per_channel=per_channel,
-                                                     out_channels=out_channels,
-                                                     symmetric=symmetric,
-                                                     narrow_range=narrow_range)
-        self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
-        self.correct_mul = P.CorrectionMul(self.is_depthwise)
-        if context.get_context('device_target') == "Ascend":
-            self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
-            self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
-        elif context.get_context('device_target') == "GPU":
-            self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn)
-            self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
-        else:
-            raise ValueError("Not support platform.")
-        self.one = Tensor(1, mstype.int32)
-        self.assignadd = P.AssignAdd()
-        self.is_gpu = context.get_context('device_target') == "GPU"
-
-    def extend_repr(self):
-        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
-            'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.fake, self.freeze_bn, self.momentum, self.quant_delay)
-        return s
-
-    def construct(self, x):
-        out_conv = self.conv(x, self.weight)
-        # BN fold1
-        batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
-                                                                               self.moving_mean,
-                                                                               self.moving_variance,
-                                                                               self.step)
-        # fake weight
-        weight = self.correct_mul(self.weight, self.gamma, running_std)
-        if self.fake:
-            weight = self.fake_quant_weight(weight)
-        out = self.conv(x, weight)
-        # BN fold2
-        if self.is_gpu:
-            if self.training:
-                out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
-                                                 batch_std, batch_mean, running_std, running_mean, self.step)
-                F.control_depend(out, self.assignadd(self.step, self.one))
-            else:
-                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
-                                                 batch_std, batch_mean, running_std, running_mean, self.step)
-        else:
-            if self.training:
-                out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
-                F.control_depend(out, self.assignadd(self.step, self.one))
-            else:
-                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
-        return out
-
-
 class Conv2dBatchNormQuant(Cell):
     r"""
     2D convolution with a folded BatchNorm layer.
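DepthwiseConv2dBatchNormQuant is removed as a separate cell; the refactored Conv2dBatchNormQuant in the next hunk absorbs the depthwise case by dispatching on group inside its constructor. A minimal migration sketch, assuming the cell is exposed through mindspore.nn as the __all__ list indicates, with illustrative shapes:

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor

    # Depthwise folded conv: group == in_channels == out_channels, which the
    # new constructor maps to P.DepthwiseConv2dNative on Ascend.
    cell = nn.Conv2dBatchNormQuant(in_channels=3,
                                   out_channels=3,
                                   kernel_size=3,
                                   stride=1,
                                   group=3,
                                   quant_delay=0,
                                   freeze_bn=100000)
    x = Tensor(np.random.rand(2, 3, 8, 8), mindspore.float32)
    y = cell(x)  # shape (2, 3, H_out, W_out)
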
@@ -593,23 +407,47 @@ class Conv2dBatchNormQuant(Cell):
         super(Conv2dBatchNormQuant, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
+        self.kernel_size = twice(kernel_size)
+        self.stride = twice(stride)
         self.pad_mode = pad_mode
         self.padding = padding
         self.dilation = twice(dilation)
-        self.stride = twice(stride)
         self.group = group
-        self.fake = fake
-        self.freeze_bn = freeze_bn
+        self.eps = eps
         self.momentum = momentum
         self.quant_delay = quant_delay
-        if isinstance(kernel_size, int):
-            self.kernel_size = (kernel_size, kernel_size)
+        self.freeze_bn = freeze_bn
+        self.fake = fake
+        self.num_bits = num_bits
+        self.per_channel = per_channel
+        self.symmetric = symmetric
+        self.narrow_range = narrow_range
+
+        # initialize convolution op and Parameter
+        if context.get_context('device_target') == "Ascend" and group > 1:
+            validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
+            validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
+            self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
+                                                kernel_size=self.kernel_size,
+                                                pad_mode=pad_mode,
+                                                pad=padding,
+                                                stride=self.stride,
+                                                dilation=self.dilation)
+            if weight_init is None:
+                weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
         else:
-            self.kernel_size = kernel_size
-        if weight_init is None:
-            weight_init = initializer(
-                'normal', [out_channels, in_channels // group, *self.kernel_size])
+            self.conv = P.Conv2D(out_channel=out_channels,
+                                 kernel_size=self.kernel_size,
+                                 pad_mode=pad_mode,
+                                 pad=padding,
+                                 stride=self.stride,
+                                 dilation=self.dilation,
+                                 group=group)
+            if weight_init is None:
+                weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
         self.weight = Parameter(weight_init, name='weight')
+
+        # initialize batchnorm Parameter
         if gamma_init is None:
             gamma_init = initializer('ones', [out_channels])
         self.gamma = Parameter(gamma_init, name='gamma')
@@ -618,16 +456,12 @@ class Conv2dBatchNormQuant(Cell):
         self.beta = Parameter(beta_init, name='beta')
         if mean_init is None:
             mean_init = initializer('zeros', [out_channels])
-        self.moving_mean = Parameter(
-            mean_init, name='moving_mean', requires_grad=False)
+        self.moving_mean = Parameter(mean_init, name='moving_mean', requires_grad=False)
         if var_init is None:
             var_init = initializer('ones', [out_channels])
-        self.moving_variance = Parameter(
-            var_init, name='moving_variance', requires_grad=False)
-
-        self.step = Parameter(initializer(
-            'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
+        self.moving_variance = Parameter(var_init, name='moving_variance', requires_grad=False)
 
+        # initialize fake ops
         self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
                                                      max_init=6,
                                                      ema=False,
@@ -638,14 +472,6 @@ class Conv2dBatchNormQuant(Cell):
                                                      symmetric=symmetric,
                                                      narrow_range=narrow_range)
         self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
-        self.conv = P.Conv2D(out_channel=out_channels,
-                             kernel_size=kernel_size,
-                             mode=1,
-                             pad_mode=pad_mode,
-                             pad=padding,
-                             stride=stride,
-                             dilation=1,
-                             group=group)
         self.correct_mul = P.CorrectionMul()
         if context.get_context('device_target') == "Ascend":
             self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
@@ -654,7 +480,8 @@ class Conv2dBatchNormQuant(Cell):
             self.batchnorm_fold2_train = P.BatchNormFold2(freeze_bn=freeze_bn)
             self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0)
         else:
-            raise ValueError("Not support platform.")
+            raise ValueError("Unsupported platform: {}".format(context.get_context('device_target')))
+        self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
         self.one = Tensor(1, mstype.int32)
         self.assignadd = P.AssignAdd()
 
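The constructor now owns op selection: Ascend with group > 1 gets P.DepthwiseConv2dNative (channel_multiplier=1, weight shape [1, in_channels, kH, kW]); everything else gets a grouped P.Conv2D (weight shape [out_channels, in_channels // group, kH, kW]). A standalone sketch of that rule, using a hypothetical helper name purely for illustration:

    def conv_dispatch(device_target, group, in_channels, out_channels, kernel_size):
        """Mirrors Conv2dBatchNormQuant's conv-op choice (hypothetical helper)."""
        kh, kw = kernel_size
        if device_target == "Ascend" and group > 1:
            # the validator calls in the hunk above imply
            # group == in_channels == out_channels for this path
            return "DepthwiseConv2dNative", [1, in_channels, kh, kw]
        return "Conv2D", [out_channels, in_channels // group, kh, kw]

    # conv_dispatch("Ascend", 3, 3, 3, (3, 3)) -> ("DepthwiseConv2dNative", [1, 3, 3, 3])
    # conv_dispatch("GPU", 1, 3, 16, (3, 3))   -> ("Conv2D", [16, 3, 3, 3])
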
platform.") + raise ValueError("Unsupported platform: {}".format(context.get_context('device_target'))) + self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False) self.one = Tensor(1, mstype.int32) self.assignadd = P.AssignAdd() @@ -926,6 +753,7 @@ class ReLUQuant(Cell): Args: num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. + ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -944,6 +772,7 @@ class ReLUQuant(Cell): def __init__(self, num_bits=8, quant_delay=0, + ema_decay=0.999, symmetric=False, narrow_range=False): super(ReLUQuant, self).__init__() @@ -952,6 +781,7 @@ class ReLUQuant(Cell): num_bits=num_bits, quant_delay=quant_delay, ema=True, + ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) self.relu = P.ReLU() @@ -973,6 +803,7 @@ class ReLU6Quant(Cell): Args: num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. + ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -988,7 +819,11 @@ class ReLU6Quant(Cell): >>> result = relu6_quant(input_x) """ - def __init__(self, num_bits=8, quant_delay=0, symmetric=False, + def __init__(self, + num_bits=8, + quant_delay=0, + ema_decay=0.999, + symmetric=False, narrow_range=False): super(ReLU6Quant, self).__init__() self.fake_quant_act = FakeQuantWithMinMax(min_init=0, @@ -996,6 +831,7 @@ class ReLU6Quant(Cell): num_bits=num_bits, quant_delay=quant_delay, ema=True, + ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) self.relu6 = P.ReLU6() @@ -1015,6 +851,7 @@ class HSwishQuant(Cell): Args: num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. + ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. @@ -1033,6 +870,7 @@ class HSwishQuant(Cell): def __init__(self, num_bits=8, quant_delay=0, + ema_decay=0.999, symmetric=False, narrow_range=False): super(HSwishQuant, self).__init__() @@ -1041,6 +879,7 @@ class HSwishQuant(Cell): num_bits=num_bits, quant_delay=quant_delay, ema=True, + ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, @@ -1048,6 +887,7 @@ class HSwishQuant(Cell): num_bits=num_bits, quant_delay=quant_delay, ema=True, + ema_decay=ema_decay, symmetric=symmetric, narrow_range=narrow_range) self.act = P.HSwish() @@ -1068,6 +908,7 @@ class HSigmoidQuant(Cell): Args: num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. quant_delay (int): Quantization delay parameters according by global step. Default: 0. + ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. 
@@ -1068,6 +908,7 @@ class HSigmoidQuant(Cell):
     Args:
         num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
         quant_delay (int): Number of global steps after which quantization starts. Default: 0.
+        ema_decay (float): Decay rate of the Exponential Moving Average. Default: 0.999.
         symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
         narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
 
@@ -1086,6 +927,7 @@ class HSigmoidQuant(Cell):
     def __init__(self,
                  num_bits=8,
                  quant_delay=0,
+                 ema_decay=0.999,
                  symmetric=False,
                  narrow_range=False):
         super(HSigmoidQuant, self).__init__()
@@ -1101,6 +943,7 @@ class HSigmoidQuant(Cell):
                                                         num_bits=num_bits,
                                                         quant_delay=quant_delay,
                                                         ema=True,
+                                                        ema_decay=ema_decay,
                                                         symmetric=symmetric,
                                                         narrow_range=narrow_range)
         self.act = P.HSigmoid()
@@ -1121,6 +964,7 @@ class TensorAddQuant(Cell):
     Args:
         num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
         quant_delay (int): Number of global steps after which quantization starts. Default: 0.
+        ema_decay (float): Decay rate of the Exponential Moving Average. Default: 0.999.
         symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
         narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
 
@@ -1140,6 +984,7 @@ class TensorAddQuant(Cell):
     def __init__(self,
                  num_bits=8,
                  quant_delay=0,
+                 ema_decay=0.999,
                  symmetric=False,
                  narrow_range=False):
         super(TensorAddQuant, self).__init__()
@@ -1148,6 +993,7 @@ class TensorAddQuant(Cell):
                                                   num_bits=num_bits,
                                                   quant_delay=quant_delay,
                                                   ema=True,
+                                                  ema_decay=ema_decay,
                                                   symmetric=symmetric,
                                                   narrow_range=narrow_range)
         self.add = P.TensorAdd()
@@ -1167,6 +1013,7 @@ class MulQuant(Cell):
     Args:
         num_bits (int): Number of quantization bits; 4 and 8 are supported. Default: 8.
         quant_delay (int): Number of global steps after which quantization starts. Default: 0.
+        ema_decay (float): Decay rate of the Exponential Moving Average. Default: 0.999.
         symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
         narrow_range (bool): Whether the quantization algorithm uses a narrow range. Default: False.
 
@@ -1181,6 +1028,7 @@ class MulQuant(Cell):
     def __init__(self,
                  num_bits=8,
                  quant_delay=0,
+                 ema_decay=0.999,
                  symmetric=False,
                  narrow_range=False):
         super(MulQuant, self).__init__()
@@ -1189,6 +1037,7 @@ class MulQuant(Cell):
                                                   num_bits=num_bits,
                                                   quant_delay=quant_delay,
                                                   ema=True,
+                                                  ema_decay=ema_decay,
                                                   symmetric=symmetric,
                                                   narrow_range=narrow_range)
         self.mul = P.Mul()
diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py
index 1e694a7dba..0a5cd54306 100644
--- a/mindspore/ops/_grad/grad_quant_ops.py
+++ b/mindspore/ops/_grad/grad_quant_ops.py
@@ -85,7 +85,7 @@ def get_bprop_batchnorm_fold2(self):
 @bprop_getters.register(P.BatchNormFoldD)
 def get_bprop_BatchNormFold(self):
     """Generate bprop for BatchNormFold for Ascend"""
-    op = P.BatchNormFoldGrad_(self.epsilon, self.is_training, self.freeze_bn)
+    op = P.BatchNormFoldGradD(self.epsilon, self.is_training, self.freeze_bn)
 
     def bprop(x, x_sum, x_square_sum, mean, variance, out, dout):
         dx = op(dout[1], dout[2], x, out[1], out[2])
diff --git a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py
index 63b9e2b7d2..39549ccfcc 100644
--- a/mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py
+++ b/mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py
@@ -16,6 +16,7 @@
 """_BatchNormFold op"""
 from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+import te
 from te import tvm
 from topi import generic
 from topi.cce import util
 
diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py
index 7cc03e2c0f..c0196d4806 100644
--- a/mindspore/ops/operations/_quant_ops.py
+++ b/mindspore/ops/operations/_quant_ops.py
@@ -31,8 +31,11 @@ __all__ = ["FakeQuantWithMinMax",
            "BatchNormFold2",
            "BatchNormFold2Grad",
            "BatchNormFoldD",
+           "BatchNormFoldGradD",
            "BNTrainingReduce",
            "BatchNormFold2_D",
+           "BatchNormFold2GradD",
+           "BatchNormFold2GradReduce",
            "FakeQuantWithMinMaxUpdate",
            ]
 
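The bprop fix above ties the Ascend gradient kernel to its exported name: BatchNormFoldGradD takes the same three positional arguments the old BatchNormFoldGrad_ call passed. A minimal construction sketch, assuming the quant ops are re-exported by mindspore.ops.operations as the bprop file's P alias suggests, with defaults taken from elsewhere in this patch (running the op still requires an Ascend target):

    from mindspore.ops import operations as P

    # (epsilon, is_training, freeze_bn), matching the bprop getter above
    bn_fold_grad = P.BatchNormFoldGradD(1e-5, True, 100000)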