From 9e7099d9abf756b0b6bd6cf1dcbbd25f067c5a32 Mon Sep 17 00:00:00 2001
From: chenzomi
Date: Mon, 18 May 2020 16:52:21 +0800
Subject: [PATCH] fix bug in nn quant

---
 mindspore/nn/cell.py        |   4 +-
 mindspore/nn/layer/quant.py | 343 ++++++++++++++++++++++--------------
 2 files changed, 215 insertions(+), 132 deletions(-)

diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index c951606207..13dac375e3 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -97,9 +97,9 @@ class Cell:
 
         After invoked, can get all the cell's children's name prefix by '_param_prefix'.
         """
-        cells = self.cells_and_names()
+        cells_name = self.cells_and_names()
 
-        for cell_name, cell in cells:
+        for cell_name, cell in cells_name:
             cell._param_prefix = cell_name
 
     @cell_init_args.setter
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index fe1c0e9f45..31df421bc0 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -15,7 +15,6 @@
 """Aware quantization."""
 
 import numpy as np
-import mindspore.nn as nn
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -24,7 +23,6 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool, twice
 from mindspore.nn.cell import Cell
-from mindspore.nn.layer.conv import _Conv
 from mindspore.nn.layer.activation import get_activation
 
 __all__ = [
@@ -37,6 +35,7 @@ __all__ = [
     'HSwishQuant',
     'HSigmoidQuant',
     'TensorAddQuant',
+    'MulQuant',
 ]
 
 
@@ -51,7 +50,7 @@ class FakeQuantWithMinMax(Cell):
         ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999.
         per_channel (bool): Quantization by layer or channel. Default: False.
-        channel_size (int): declarate the min and max channel size, Default: 1.
+        out_channels (int): Size of the per-channel min and max arrays. Default: 1.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
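The rename above changes how per-channel fake quantization is configured. A
minimal usage sketch (hypothetical, assuming the class is importable from
mindspore.nn.layer.quant and that axis 0 is the channel axis, as it is for the
conv weights this layer quantizes below):

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import FakeQuantWithMinMax

    # One (min, max) pair is kept per channel, so `out_channels` must match
    # the channel dimension of the tensor being fake-quantized.
    fake_quant = FakeQuantWithMinMax(min_init=-6, max_init=6,
                                     per_channel=True, out_channels=4)
    w = Tensor(np.random.uniform(-3, 3, (4, 3, 3, 3)).astype(np.float32))
    w_q = fake_quant(w)  # same shape and dtype as w, values snapped to 8-bit levels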
@@ -71,7 +70,7 @@ class FakeQuantWithMinMax(Cell):
                  ema=False,
                  ema_decay=0.999,
                  per_channel=False,
-                 channel_size=1,
+                 out_channels=1,
                  quant_delay=0,
                  symmetric=False,
                  narrow_range=False):
@@ -83,16 +82,16 @@ class FakeQuantWithMinMax(Cell):
         self.ema = ema
         self.ema_decay = ema_decay
         self.per_channel = per_channel
-        self.channel_size = channel_size
+        self.out_channels = out_channels
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
 
         if per_channel:
             min_array = np.array([self.min_init for i in range(
-                0, self.channel_size)]).astype(np.float32)
+                0, self.out_channels)]).astype(np.float32)
             max_array = np.array([self.max_init for i in range(
-                0, self.channel_size)]).astype(np.float32)
+                0, self.out_channels)]).astype(np.float32)
             self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
                                                                     ema=self.ema,
                                                                     ema_decay=self.ema_decay,
@@ -102,8 +101,8 @@ class FakeQuantWithMinMax(Cell):
                                                                     training=True)
             self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
                                                                     ema=self.ema,
-                                                                    ema_decay=ema_decay,
-                                                                    quant_delay=quant_delay,
+                                                                    ema_decay=self.ema_decay,
+                                                                    quant_delay=self.quant_delay,
                                                                     symmetric=self.symmetric,
                                                                     narrow_range=self.narrow_range,
                                                                     training=False)
@@ -119,28 +118,27 @@ class FakeQuantWithMinMax(Cell):
                                                           training=True)
             self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
                                                           ema=self.ema,
-                                                          ema_decay=ema_decay,
-                                                          quant_delay=quant_delay,
+                                                          ema_decay=self.ema_decay,
+                                                          quant_delay=self.quant_delay,
                                                           symmetric=self.symmetric,
                                                           narrow_range=self.narrow_range,
                                                           training=False)
 
-        self.min = Parameter(
+        self.minq = Parameter(
             Tensor(min_array), name='quant_min', requires_grad=False)
-        self.max = Parameter(
+        self.maxq = Parameter(
             Tensor(max_array), name='quant_max', requires_grad=False)
 
     def extend_repr(self):
-        s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format(
-            self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size,
-            self.quant_delay)
+        s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
+            self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay)
         return s
 
     def construct(self, x):
         if self.training:
-            out = self.fake_quant_train(x, self.min, self.max)
+            out = self.fake_quant_train(x, self.minq, self.maxq)
         else:
-            out = self.fake_quant_infer(x, self.min, self.max)
+            out = self.fake_quant_infer(x, self.minq, self.maxq)
         return out
 
 
@@ -188,13 +186,13 @@ class Conv2dBatchNormQuant(Cell):
                  in_channels,
                  out_channels,
                  kernel_size,
-                 stride,
-                 pad_mode,
+                 stride=1,
+                 pad_mode='same',
                  padding=0,
                  dilation=1,
                  group=1,
                  eps=1e-5,
-                 momentum=0.9,
+                 momentum=0.997,
                  weight_init=None,
                  beta_init=None,
                  gamma_init=None,
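With the new signature defaults (stride=1, pad_mode='same', momentum=0.997),
the folded-batchnorm conv layer can be constructed from the channel arguments
alone. A minimal sketch (hypothetical values; NCHW input assumed):

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import Conv2dBatchNormQuant

    # 3x3 conv, 3 -> 16 channels; stride, pad_mode and momentum use defaults.
    conv_bn = Conv2dBatchNormQuant(3, 16, kernel_size=3,
                                   quant_delay=900, freeze_bn=100000)
    x = Tensor(np.ones((1, 3, 32, 32)).astype(np.float32))
    out = conv_bn(x)  # (1, 16, 32, 32) because pad_mode defaults to 'same'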
@@ -208,24 +206,25 @@ class Conv2dBatchNormQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(Conv2dBatchNormQuant, self).__init__()
-        _ = dilation
-        self.stride = stride
-        self.conv = P.Conv2D(out_channel=out_channels,
-                             kernel_size=kernel_size,
-                             mode=1,
-                             pad_mode=pad_mode,
-                             pad=padding,
-                             stride=stride,
-                             dilation=1,
-                             group=group)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.pad_mode = pad_mode
+        self.padding = padding
+        self.dilation = twice(dilation)
+        self.stride = twice(stride)
+        self.group = group
         self.fake = fake
         self.freeze_bn = freeze_bn
+        self.momentum = momentum
+        self.quant_delay = quant_delay
         if isinstance(kernel_size, int):
-            kernel_size = (kernel_size, kernel_size)
+            self.kernel_size = (kernel_size, kernel_size)
+        else:
+            self.kernel_size = kernel_size
         if weight_init is None:
             weight_init = initializer(
-                'normal', [out_channels, in_channels // group, *kernel_size])
+                'normal', [out_channels, in_channels // group, *self.kernel_size])
         self.weight = Parameter(weight_init, name='weight')
         if gamma_init is None:
             gamma_init = initializer('ones', [out_channels])
@@ -245,16 +244,23 @@ class Conv2dBatchNormQuant(Cell):
         self.step = Parameter(initializer(
             'normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
 
-        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
-                                                        max_init=6,
-                                                        ema=False,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
-                                                        per_channel=per_channel,
-                                                        channel_size=out_channels,
-                                                        symmetric=symmetric,
-                                                        narrow_range=narrow_range)
-
+        self.conv = P.Conv2D(out_channel=self.out_channels,
+                             kernel_size=self.kernel_size,
+                             mode=1,
+                             pad_mode=self.pad_mode,
+                             pad=self.padding,
+                             stride=self.stride,
+                             dilation=self.dilation,
+                             group=self.group)
+        self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
+                                                     max_init=6,
+                                                     ema=False,
+                                                     num_bits=num_bits,
+                                                     quant_delay=quant_delay,
+                                                     per_channel=per_channel,
+                                                     out_channels=out_channels,
+                                                     symmetric=symmetric,
+                                                     narrow_range=narrow_range)
         self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps,
                                                     momentum=momentum,
                                                     is_training=True,
@@ -271,7 +277,12 @@ class Conv2dBatchNormQuant(Cell):
         self.assignadd = P.AssignAdd()
 
     def extend_repr(self):
-        s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn)
+        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
+            'pad_mode={}, padding={}, dilation={}, group={}, ' \
+            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.fake, self.freeze_bn, self.momentum, self.quant_delay)
         return s
 
     def construct(self, x):
@@ -295,9 +306,8 @@ class Conv2dBatchNormQuant(Cell):
             F.control_depend(out, self.assignadd(self.step, self.one))
         else:
             step = self.step
-            out_conv = self.conv(x, self.weight)
             batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer(
-                out_conv, self.moving_mean, self.moving_variance, step)
+                x, self.moving_mean, self.moving_variance, step)
             weight = self.correct_mul(self.weight, self.gamma, running_std)
             if self.fake:
                 weight = self.fake_quant_weight(weight)
@@ -307,7 +317,7 @@ class Conv2dBatchNormQuant(Cell):
         return out
 
 
-class Conv2dQuant(_Conv):
+class Conv2dQuant(Cell):
     r"""
     2D convolution with fake quant op layer.
 
@@ -325,8 +335,8 @@ class Conv2dQuant(Cell):
             divisible by the number of groups. Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
-            Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
+            Default: None.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None.
         quant_delay (int): Quantization delay parameters according by global step. Default: 0.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
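Because weight_init and bias_init now default to None and are materialized
inside __init__, the layer can be constructed without explicit initializers.
A minimal sketch (hypothetical sizes):

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import Conv2dQuant

    conv = Conv2dQuant(3, 16, kernel_size=3, pad_mode='same', has_bias=True)
    x = Tensor(np.ones((1, 3, 32, 32)).astype(np.float32))
    out = conv(x)  # the weight is fake-quantized before the convolution runs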
@@ -351,40 +361,72 @@ class Conv2dQuant(_Conv):
                  dilation=1,
                  group=1,
                  has_bias=False,
-                 weight_init='normal',
-                 bias_init='zeros',
+                 weight_init=None,
+                 bias_init=None,
                  quant_delay=0,
                  num_bits=8,
                  per_channel=False,
                  symmetric=False,
                  narrow_range=False):
-        kernel_size = twice(kernel_size)
-        super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation,
-                                          group, has_bias, weight_init, bias_init)
-        self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1,
-                               pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation,
-                               group=self.group)
-        self.bias_add = P.BiasAdd()
-        if pad_mode not in ('valid', 'same', 'pad'):
-            raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
-                             + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
-        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
-                                                        max_init=6,
-                                                        ema=False,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
-                                                        per_channel=per_channel,
-                                                        channel_size=out_channels,
-                                                        symmetric=symmetric,
-                                                        narrow_range=narrow_range)
+        super(Conv2dQuant, self).__init__()
+        if isinstance(kernel_size, int):
+            self.kernel_size = (kernel_size, kernel_size)
+        else:
+            self.kernel_size = kernel_size
+        self.in_channels = check_int_positive(in_channels)
+        self.out_channels = check_int_positive(out_channels)
+        self.has_bias = has_bias
+        self.stride = twice(stride)
+        self.dilation = twice(dilation)
+        self.pad_mode = pad_mode
+        self.padding = padding
+        self.group = group
+        self.quant_delay = quant_delay
+
+        if weight_init is None:
+            weight_init = initializer(
+                'normal', [out_channels, in_channels // group, *self.kernel_size])
+        self.weight = Parameter(weight_init, name='weight')
+        if bias_init is None:
+            bias_init = initializer('zeros', [out_channels])
+        if has_bias:
+            self.bias = Parameter(bias_init, name='bias')
+            self.bias_add = P.BiasAdd()
+
+        self.conv = P.Conv2D(out_channel=self.out_channels,
+                             kernel_size=self.kernel_size,
+                             mode=1,
+                             pad_mode=self.pad_mode,
+                             pad=self.padding,
+                             stride=self.stride,
+                             dilation=self.dilation,
+                             group=self.group)
+        self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
+                                                     max_init=6,
+                                                     ema=False,
+                                                     num_bits=num_bits,
+                                                     quant_delay=quant_delay,
+                                                     per_channel=per_channel,
+                                                     out_channels=out_channels,
+                                                     symmetric=symmetric,
+                                                     narrow_range=narrow_range)
 
     def construct(self, x):
-        weight_q = self.fake_quant_weight(self.weight)
-        out = self.conv2d(x, weight_q)
+        weight = self.fake_quant_weight(self.weight)
+        out = self.conv(x, weight)
         if self.has_bias:
             return self.bias_add(out, self.bias)
         return out
 
+    def extend_repr(self):
+        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
+            'pad_mode={}, padding={}, dilation={}, group={}, ' \
+            'has_bias={}, quant_delay={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.has_bias, self.quant_delay)
+        return s
+
 
 class DenseQuant(Cell):
     r"""
@@ -453,15 +495,15 @@ class DenseQuant(Cell):
 
         self.activation = get_activation(activation)
         self.activation_flag = self.activation is not None
-        self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6,
-                                                        max_init=6,
-                                                        ema=False,
-                                                        num_bits=num_bits,
-                                                        quant_delay=quant_delay,
-                                                        per_channel=per_channel,
-                                                        channel_size=out_channels,
-                                                        symmetric=symmetric,
-                                                        narrow_range=narrow_range)
+        self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
+                                                     max_init=6,
+                                                     ema=False,
+                                                     num_bits=num_bits,
+                                                     quant_delay=quant_delay,
+                                                     per_channel=per_channel,
+                                                     out_channels=out_channels,
+                                                     symmetric=symmetric,
+                                                     narrow_range=narrow_range)
 
     def construct(self, x):
         """Use operators to construct to Dense layer."""
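The fully connected layer takes the same FakeQuantWithMinMax arguments, so a
per-channel weight quantizer is requested the same way. A minimal sketch
(hypothetical sizes):

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import DenseQuant

    dense = DenseQuant(16, 10, activation='relu', per_channel=True)
    x = Tensor(np.ones((2, 16)).astype(np.float32))
    out = dense(x)  # (2, 10); the weight is fake-quantized per output channel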
@@ -511,13 +553,13 @@ class ReLUQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(ReLUQuant, self).__init__()
-        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
-                                                     max_init=6,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
-                                                     ema=True,
-                                                     symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+        self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
+                                                  max_init=6,
+                                                  num_bits=num_bits,
+                                                  quant_delay=quant_delay,
+                                                  ema=True,
+                                                  symmetric=symmetric,
+                                                  narrow_range=narrow_range)
         self.relu = P.ReLU()
 
     def construct(self, x):
@@ -551,13 +593,13 @@ class ReLU6Quant(Cell):
     def __init__(self, num_bits=8, quant_delay=0, symmetric=False,
                  narrow_range=False):
         super(ReLU6Quant, self).__init__()
-        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0,
-                                                     max_init=6,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
-                                                     ema=True,
-                                                     symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+        self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
+                                                  max_init=6,
+                                                  num_bits=num_bits,
+                                                  quant_delay=quant_delay,
+                                                  ema=True,
+                                                  symmetric=symmetric,
+                                                  narrow_range=narrow_range)
         self.relu6 = P.ReLU6()
 
     def construct(self, x):
@@ -592,20 +634,20 @@ class HSwishQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(HSwishQuant, self).__init__()
-        self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
-                                                            max_init=6,
-                                                            num_bits=num_bits,
-                                                            quant_delay=quant_delay,
-                                                            ema=True,
-                                                            symmetric=symmetric,
-                                                            narrow_range=narrow_range)
-        self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
-                                                           max_init=6,
-                                                           num_bits=num_bits,
-                                                           quant_delay=quant_delay,
-                                                           ema=True,
-                                                           symmetric=symmetric,
-                                                           narrow_range=narrow_range)
+        self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
+                                                         max_init=6,
+                                                         num_bits=num_bits,
+                                                         quant_delay=quant_delay,
+                                                         ema=True,
+                                                         symmetric=symmetric,
+                                                         narrow_range=narrow_range)
+        self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
+                                                        max_init=6,
+                                                        num_bits=num_bits,
+                                                        quant_delay=quant_delay,
+                                                        ema=True,
+                                                        symmetric=symmetric,
+                                                        narrow_range=narrow_range)
         self.act = P.HSwish()
 
     def construct(self, x):
@@ -641,20 +683,20 @@ class HSigmoidQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(HSigmoidQuant, self).__init__()
-        self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0,
-                                                            max_init=6,
-                                                            num_bits=num_bits,
-                                                            quant_delay=quant_delay,
-                                                            ema=True,
-                                                            symmetric=symmetric,
-                                                            narrow_range=narrow_range)
-        self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0,
-                                                           max_init=6,
-                                                           num_bits=num_bits,
-                                                           quant_delay=quant_delay,
-                                                           ema=True,
-                                                           symmetric=symmetric,
-                                                           narrow_range=narrow_range)
+        self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
+                                                         max_init=6,
+                                                         num_bits=num_bits,
+                                                         quant_delay=quant_delay,
+                                                         ema=True,
+                                                         symmetric=symmetric,
+                                                         narrow_range=narrow_range)
+        self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
+                                                        max_init=6,
+                                                        num_bits=num_bits,
+                                                        quant_delay=quant_delay,
+                                                        ema=True,
+                                                        symmetric=symmetric,
+                                                        narrow_range=narrow_range)
         self.act = P.HSigmoid()
 
     def construct(self, x):
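In the two hunks above, the fake-quant nodes around h-swish and h-sigmoid now
cover a signed range (min_init=-6 instead of 0), so negative inputs are no
longer clipped to zero during simulated quantization. A minimal sketch
(hypothetical input; import path assumed as elsewhere in this patch):

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import HSwishQuant

    act = HSwishQuant(num_bits=8)
    x = Tensor(np.linspace(-4, 4, 8).astype(np.float32))
    out = act(x)  # input and output both pass through signed fake-quant nodes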
@@ -690,16 +732,57 @@ class TensorAddQuant(Cell):
                  symmetric=False,
                  narrow_range=False):
         super(TensorAddQuant, self).__init__()
-        self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6,
-                                                     max_init=6,
-                                                     num_bits=num_bits,
-                                                     quant_delay=quant_delay,
-                                                     ema=True,
-                                                     symmetric=symmetric,
-                                                     narrow_range=narrow_range)
+        self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
+                                                  max_init=6,
+                                                  num_bits=num_bits,
+                                                  quant_delay=quant_delay,
+                                                  ema=True,
+                                                  symmetric=symmetric,
+                                                  narrow_range=narrow_range)
         self.add = P.TensorAdd()
 
     def construct(self, x1, x2):
         x = self.add(x1, x2)
         x = self.fake_quant_act(x)
         return x
+
+
+class MulQuant(Cell):
+    r"""
+    Add Fake Quant OP after Mul OP.
+
+    For a more detailed overview of the Mul op.
+
+    Args:
+        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
+        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
+        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
+        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
+
+    Inputs:
+        - **x1** (Tensor) - The first input of MulQuant.
+        - **x2** (Tensor) - The second input of MulQuant.
+
+    Outputs:
+        Tensor, with the same type and shape as the `x1`.
+
+    """
+
+    def __init__(self,
+                 num_bits=8,
+                 quant_delay=0,
+                 symmetric=False,
+                 narrow_range=False):
+        super(MulQuant, self).__init__()
+        self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
+                                                  max_init=6,
+                                                  num_bits=num_bits,
+                                                  quant_delay=quant_delay,
+                                                  ema=True,
+                                                  symmetric=symmetric,
+                                                  narrow_range=narrow_range)
+        self.mul = P.Mul()
+
+    def construct(self, x1, x2):
+        x = self.mul(x1, x2)
+        x = self.fake_quant_act(x)
+        return x
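The new MulQuant cell mirrors TensorAddQuant: multiply the two inputs, then
fake-quantize the product. A usage sketch (hypothetical values):

    import numpy as np
    from mindspore import Tensor
    from mindspore.nn.layer.quant import MulQuant

    mul = MulQuant(num_bits=8)
    a = Tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
    b = Tensor(np.array([[0.5, 0.5], [2.0, 2.0]], dtype=np.float32))
    out = mul(a, b)  # elementwise product, passed through FakeQuantWithMinMax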