diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index a0b2e5bdb2..31fbbb9651 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -102,13 +102,13 @@ class BatchNormFoldCell(Cell):
         return batch_mean, batch_std, running_mean, running_std
 
 
-class FakeQuantWithMinMaxAscend(Cell):
+class FakeQuantWithMinMax(Cell):
     r"""
     Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
 
     Args:
-        min_init (int, list): The dimension of channel or 1(layer). Default: -6.
-        max_init (int, list): The dimension of channel or 1(layer). Default: 6.
+        min_init (int, float): The dimension of channel or 1(layer). Default: -6.
+        max_init (int, float): The dimension of channel or 1(layer). Default: 6.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
         ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
@@ -143,32 +143,33 @@ class FakeQuantWithMinMaxAscend(Cell):
                  symmetric=False,
                  narrow_range=False,
                  training=True):
-        """init FakeQuantWithMinMaxAscend layer"""
-        super(FakeQuantWithMinMaxAscend, self).__init__()
+        """init FakeQuantWithMinMax layer"""
+        super(FakeQuantWithMinMax, self).__init__()
         self.min_init = min_init
         self.max_init = max_init
         self.num_bits = num_bits
         self.ema = ema
         self.ema_decay = ema_decay
         self.per_channel = per_channel
+        self.out_channels = out_channels
         self.channel_axis = channel_axis
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
         self.training = training
+        self.is_ascend = context.get_context('device_target') == "Ascend"
 
         # init tensor min and max for fake quant op
-        if isinstance(min_init, int):
-            min_array = np.array([min_init]).reshape(1).astype(np.float32)
-            max_array = np.array([max_init]).reshape(1).astype(np.float32)
-        elif isinstance(min_init, list):
-            min_array = np.array([self.min_init for i in range(
-                0, self.out_channels)]).astype(np.float32)
-            max_array = np.array([self.max_init for i in range(
-                0, self.out_channels)]).astype(np.float32)
+        if self.per_channel:
+            min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
+            max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32)
+        else:
+            min_array = np.array([self.min_init]).reshape(1).astype(np.float32)
+            max_array = np.array([self.max_init]).reshape(1).astype(np.float32)
         self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
 
+        # init fake quant relative op
         if per_channel:
             quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
             ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis)
@@ -176,28 +177,36 @@
             quant_fun = P.FakeQuantPerLayer
             ema_fun = P.FakeQuantMinMaxPerLayerUpdate
 
-        self.fake_quant = quant_fun(num_bits=self.num_bits,
-                                    ema=self.ema,
-                                    ema_decay=self.ema_decay,
-                                    quant_delay=self.quant_delay,
-                                    symmetric=self.symmetric,
-                                    narrow_range=self.narrow_range,
-                                    training=self.training)
-        self.ema_update = ema_fun(num_bits=self.num_bits,
-                                  ema=self.ema,
-                                  ema_decay=self.ema_decay,
-                                  symmetric=self.symmetric,
-                                  narrow_range=self.narrow_range,
-                                  training=self.training)
+        if self.is_ascend:
+            self.fake_quant = quant_fun(num_bits=self.num_bits,
+                                        symmetric=self.symmetric,
+                                        narrow_range=self.narrow_range,
+                                        training=self.training)
+        else:
+            self.fake_quant = quant_fun(num_bits=self.num_bits,
+                                        ema=self.ema,
+                                        ema_decay=ema_decay,
+                                        quant_delay=quant_delay,
+                                        symmetric=self.symmetric,
+                                        narrow_range=self.narrow_range,
+                                        training=self.training)
+        if self.ema:
+            self.ema_update = ema_fun(num_bits=self.num_bits,
+                                      ema=self.ema,
+                                      ema_decay=self.ema_decay,
+                                      symmetric=self.symmetric,
+                                      narrow_range=self.narrow_range,
+                                      training=self.training)
 
     def extend_repr(self):
-        s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
-            self.min_init, self.max_init, self.ema, self.ema_decay,
-            self.per_channel, self.quant_delay, self.channel_axis)
+        s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
+            'quant_delay={}, min_init={}, max_init={}'.format(
+                self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
+                self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init)
         return s
 
     def construct(self, x):
-        if self.update:
+        if self.ema and self.is_ascend:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
             out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
@@ -207,104 +216,6 @@ class FakeQuantWithMinMaxAscend(Cell):
         return out
 
 
-class FakeQuantWithMinMaxGPU(Cell):
-    r"""
-    Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
-
-    Args:
-        min_init (int, list): The dimension of channel or 1(layer). Default: -6.
-        max_init (int, list): The dimension of channel or 1(layer). Default: 6.
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
-        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
-        per_channel (bool): Quantization by layer or channel. Default: False.
-        out_channels (int): declarate the min and max channel size, Default: 1.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-
-    Inputs:
-        - **x** (Tensor) - The input of FakeQuantWithMinMax.
-
-    Outputs:
-        Tensor, with the same type and shape as the `x`.
-
-    Examples:
-        >>> fake_quant = FakeQuantWithMinMax()
-        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = fake_quant(input_x)
-    """
-
-    def __init__(self,
-                 min_init=-6,
-                 max_init=6,
-                 num_bits=8,
-                 ema=False,
-                 ema_decay=0.999,
-                 per_channel=False,
-                 channel_axis=1,
-                 out_channels=1,
-                 quant_delay=0,
-                 symmetric=False,
-                 narrow_range=False,
-                 training=True):
-        super(FakeQuantWithMinMaxGPU, self).__init__()
-        self.min_init = min_init
-        self.max_init = max_init
-        self.num_bits = num_bits
-        self.ema = ema
-        self.ema_decay = ema_decay
-        self.per_channel = per_channel
-        self.channel_axis = channel_axis
-        self.quant_delay = quant_delay
-        self.symmetric = symmetric
-        self.narrow_range = narrow_range
-        self.training = training
-
-        # init tensor min and max for fake quant op
-        if isinstance(min_init, int):
-            min_array = np.array([min_init]).reshape(1).astype(np.float32)
-            max_array = np.array([max_init]).reshape(1).astype(np.float32)
-        elif isinstance(min_init, list):
-            min_array = np.array([self.min_init for i in range(
-                0, self.out_channels)]).astype(np.float32)
-            max_array = np.array([self.max_init for i in range(
-                0, self.out_channels)]).astype(np.float32)
-        self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
-        self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
-        if per_channel:
-            quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
-        else:
-            quant_fun = P.FakeQuantPerLayer
-        self.fake_quant = quant_fun(num_bits=self.num_bits,
-                                    ema=self.ema,
-                                    ema_decay=ema_decay,
-                                    quant_delay=quant_delay,
-                                    symmetric=self.symmetric,
-                                    narrow_range=self.narrow_range,
-                                    training=self.training)
-
-    def extend_repr(self):
-        s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format(
-            self.min_init, self.max_init, self.ema, self.ema_decay,
-            self.per_channel, self.quant_delay, self.channel_axis)
-        return s
-
-    def construct(self, x):
-        out = self.fake_quant(x, self.minq, self.maxq)
-        return out
-
-
-def FakeQuantWithMinMax(**kwargs):
-    if context.get_context('device_target') == "Ascend":
-        out = FakeQuantWithMinMaxAscend(**kwargs)
-    if context.get_context('device_target') == "GPU":
-        out = FakeQuantWithMinMaxGPU(**kwargs)
-    else:
-        raise ValueError("Not support platform or channel mode.")
-    return out
-
-
 class Conv2dBatchNormQuant(Cell):
     r"""
     2D convolution with BatchNormal op folded layer.
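For reference only (not part of the patch): with the Ascend and GPU variants merged above, backend selection now happens inside FakeQuantWithMinMax itself via context.get_context('device_target'). Below is a minimal usage sketch in the doctest style of the docstring's Examples block; the import path mindspore.nn.layer.quant and the use of context.set_context to pick the device are assumptions based on this file, not something the patch itself demonstrates.

>>> # Sketch only: mirrors the Examples block of the docstring above.
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, context
>>> from mindspore.nn.layer.quant import FakeQuantWithMinMax
>>> context.set_context(device_target="GPU")   # "Ascend" would select the TBE fake-quant ops instead
>>> fake_quant = FakeQuantWithMinMax(min_init=-6, max_init=6, num_bits=8)
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)               # same shape and dtype as input_x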
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
index 827d7a433c..f6c133c808 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
@@ -92,6 +92,7 @@ def fake_quant_perchannel(x, min_val, max_val, y,
                           kernel_name="fake_quant_perchannel"):
     """FakeQuantPerChannel"""
     x_shape = x.get("shape")
+    x_shape_ = x.get("ori_shape")
     x_format = x.get("format")
     x_dtype = x.get("dtype")
     min_shape = min_val.get("ori_shape")
@@ -101,8 +102,8 @@
 
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
index 91fb694154..4e9053fcb1 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
@@ -117,6 +117,7 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
                                kernel_name="fake_quant_perchannel_grad"):
     """FakeQuantPerChannelGrad"""
     x_shape = x.get("shape")
+    x_shape_ = x.get("ori_shape")
     x_format = x.get("format")
     x_dtype = x.get("dtype")
     min_shape = min_val.get("ori_shape")
@@ -126,8 +127,8 @@
 
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py
index 81322acccf..20b39dc257 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py
@@ -43,7 +43,7 @@ fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \
     .get_op_info()
 
 
-@op_info_register(fake_quant_op_info)
+@op_info_register(fake_quant_per_layer_op_info)
 def _fake_quant_per_layer_tbe():
     """FakeQuantPerLayer TBE register"""
     return
diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py
index 6aa30ab2f3..b228c51b10 100644
--- a/mindspore/ops/operations/_quant_ops.py
+++ b/mindspore/ops/operations/_quant_ops.py
@@ -110,10 +110,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
 
     def infer_shape(self, x_shape, min_shape, max_shape):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
-        validator.check("min shape", min_shape, "max shape",
-                        max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(
-            min_shape), 1, Rel.EQ, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
+        validator.check_integer("min shape", len(min_shape), 1, Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
@@ -168,7 +166,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
                         x_shape, Rel.EQ, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(
+        validator.check_integer("min shape", len(
             min_shape), 1, Rel.EQ, self.name)
         return dout_shape
 
@@ -255,10 +253,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
 
     def infer_shape(self, x_shape, min_shape, max_shape):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
         validator.check_integer(
-            "min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
+            "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
         validator.check_integer(
-            "max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
+            "max shape", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, x_type, min_type, max_type):
@@ -379,7 +378,7 @@ class BatchNormFold(PrimitiveWithInfer):
     def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
         validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
         validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
         return mean_shape, mean_shape, mean_shape, mean_shape
 
     def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
@@ -426,7 +425,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
                         "batch_std shape", batch_std_shape, Rel.EQ, self.name)
         validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
                         "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
@@ -568,7 +567,7 @@ class BatchNormFold2(PrimitiveWithInfer):
         validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name)
         validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size",
                         x_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
         return x_shape
 
     def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
@@ -616,7 +615,7 @@ class BatchNormFold2Grad(PrimitiveWithInfer):
         validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
         validator.check("batch_std size", batch_std_shape[0], "dout channel size",
                         dout_shape[self.channel_axis], Rel.EQ, self.name)
-        validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name)
+        validator.check_integer("global step shape len", len(global_step_shape), 1, Rel.EQ, self.name)
         return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
 
     def infer_dtype(self, dout_type, x_type, gamma_type,
@@ -887,7 +886,7 @@ class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(
+        validator.check_integer("min shape", len(
             min_shape), 1, Rel.EQ, self.name)
         return min_shape, max_shape
 
@@ -963,7 +962,7 @@ class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
         validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
-        validator.check_integer("min rank", len(
+        validator.check_integer("min shape", len(
             min_shape), 1, Rel.EQ, self.name)
         return min_shape, max_shape
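For reference only (not part of the patch): the FakeQuantPerChannel.infer_shape change above now also requires min and max to share the same shape, in addition to the existing check that their first dimension equals x_shape[self.channel_axis]. The sketch below shows inputs that satisfy both checks; the module alias Q, constructing the primitive with only num_bits and channel_axis (the way quant.py binds channel_axis via functools.partial), and relying on default values for the remaining attributes are all assumptions.

>>> # Sketch only: the import alias Q and the default attribute values are assumptions.
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops.operations import _quant_ops as Q
>>> op = Q.FakeQuantPerChannel(num_bits=8, channel_axis=0)
>>> x = Tensor(np.random.uniform(-2, 2, (3, 4)).astype(np.float32))
>>> min_val = Tensor(np.array([-6.0, -6.0, -6.0], dtype=np.float32))  # 1-D, length == x.shape[0]
>>> max_val = Tensor(np.array([6.0, 6.0, 6.0], dtype=np.float32))     # must match min_val's shape
>>> out = op(x, min_val, max_val)                                     # output shape == x.shape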