@@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
   return true;
 }
-MS_REG_GPU_KERNEL(FakeQuantWithMinMax, FakeQuantGpuKernel)
+MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
 }  // namespace kernel
 }  // namespace mindspore
@@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
   return true;
 }
-MS_REG_GPU_KERNEL(FakeQuantWithMinMaxGrad, FakeQuantGradGpuKernel)
+MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
 }  // namespace kernel
 }  // namespace mindspore
@@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
   return true;
 }
-MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannel, FakeQuantPerChannelGpuKernel)
+MS_REG_GPU_KERNEL(FakeQuantPerChannel, FakeQuantPerChannelGpuKernel)
 }  // namespace kernel
 }  // namespace mindspore
@@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
   return true;
 }
-MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
+MS_REG_GPU_KERNEL(FakeQuantPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
 }  // namespace kernel
 }  // namespace mindspore
| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """Aware quantization.""" | """Aware quantization.""" | ||||
| from functools import partial | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| @@ -101,10 +102,9 @@ class BatchNormFoldCell(Cell): | |||||
| return batch_mean, batch_std, running_mean, running_std | return batch_mean, batch_std, running_mean, running_std | ||||
| class FakeQuantWithMinMaxD(Cell): | |||||
| class FakeQuantWithMinMaxAscend(Cell): | |||||
| r""" | r""" | ||||
| Aware Quantization training op of ascend. This OP provide Fake quantization observer | |||||
| function on data with min and max. | |||||
| Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. | |||||
| Args: | Args: | ||||
| min_init (int, list): The dimension of channel or 1(layer). Default: -6. | min_init (int, list): The dimension of channel or 1(layer). Default: -6. | ||||
| @@ -125,7 +125,7 @@ class FakeQuantWithMinMaxD(Cell): | |||||
| Tensor, with the same type and shape as the `x`. | Tensor, with the same type and shape as the `x`. | ||||
| Examples: | Examples: | ||||
| >>> fake_quant = nn.FakeQuantWithMinMaxD() | |||||
| >>> fake_quant = FakeQuantWithMinMax() | |||||
| >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) | >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) | ||||
| >>> result = fake_quant(input_x) | >>> result = fake_quant(input_x) | ||||
| """ | """ | ||||
| @@ -137,75 +137,77 @@ class FakeQuantWithMinMaxD(Cell): | |||||
| ema=False, | ema=False, | ||||
| ema_decay=0.999, | ema_decay=0.999, | ||||
| per_channel=False, | per_channel=False, | ||||
| channel_size=1, | |||||
| channel_axis=1, | |||||
| out_channels=1, | |||||
| quant_delay=0, | quant_delay=0, | ||||
| symmetric=False, | symmetric=False, | ||||
| narrow_range=False, | narrow_range=False, | ||||
| training=True): | training=True): | ||||
| """init FakeQuantWithMinMax ascend layer""" | |||||
| super(FakeQuantWithMinMaxD, self).__init__() | |||||
| """init FakeQuantWithMinMaxAscend layer""" | |||||
| super(FakeQuantWithMinMaxAscend, self).__init__() | |||||
| self.min_init = min_init | self.min_init = min_init | ||||
| self.num_bits = num_bits | |||||
| self.max_init = max_init | self.max_init = max_init | ||||
| self.num_bits = num_bits | |||||
| self.ema = ema | self.ema = ema | ||||
| self.ema_decay = ema_decay | self.ema_decay = ema_decay | ||||
| self.per_channel = per_channel | self.per_channel = per_channel | ||||
| self.channel_size = channel_size | |||||
| self.channel_axis = channel_axis | |||||
| self.quant_delay = quant_delay | self.quant_delay = quant_delay | ||||
| self.symmetric = symmetric | self.symmetric = symmetric | ||||
| self.narrow_range = narrow_range | self.narrow_range = narrow_range | ||||
| self.training = training | self.training = training | ||||
| if not per_channel: | |||||
| self.fake_quant = P.FakeQuantWithMinMax(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=self.ema_decay, | |||||
| quant_delay=self.quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=training) | |||||
| self.ema_update = P.FakeQuantWithMinMaxUpdate(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=self.ema_decay, | |||||
| quant_delay=self.quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=training) | |||||
| else: | |||||
| raise RuntimeError("not support per channel") | |||||
| # init tensor min and max for fake quant op | |||||
| if isinstance(min_init, int): | |||||
| min_array = np.array([min_init]).reshape(1).astype(np.float32) | |||||
| max_array = np.array([max_init]).reshape(1).astype(np.float32) | |||||
| elif isinstance(min_init, list): | |||||
| min_array = np.array([self.min_init for i in range( | |||||
| 0, self.out_channels)]).astype(np.float32) | |||||
| max_array = np.array([self.max_init for i in range( | |||||
| 0, self.out_channels)]).astype(np.float32) | |||||
| self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) | |||||
| self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) | |||||
| if isinstance(min_init, Parameter): | |||||
| self.minq = min_init | |||||
| self.maxq = max_init | |||||
| if per_channel: | |||||
| quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis) | |||||
| ema_fun = partial(P.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) | |||||
| else: | else: | ||||
| self.minq = Parameter(Tensor(np.array([min_init]).astype(np.float32)), | |||||
| name='quant_min', | |||||
| requires_grad=False) | |||||
| self.maxq = Parameter(Tensor(np.array([max_init]).astype(np.float32)), | |||||
| name='quant_max', | |||||
| requires_grad=False) | |||||
| self.reduce_min = P.ReduceMin() | |||||
| self.reduce_max = P.ReduceMax() | |||||
| quant_fun = P.FakeQuantPerLayer | |||||
| ema_fun = P.FakeQuantMinMaxPerLayerUpdate | |||||
| self.fake_quant = quant_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=self.ema_decay, | |||||
| quant_delay=self.quant_delay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=self.training) | |||||
| self.ema_update = ema_fun(num_bits=self.num_bits, | |||||
| ema=self.ema, | |||||
| ema_decay=self.ema_decay, | |||||
| symmetric=self.symmetric, | |||||
| narrow_range=self.narrow_range, | |||||
| training=self.training) | |||||
| def extend_repr(self): | def extend_repr(self): | ||||
| s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format( | |||||
| self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size, | |||||
| self.quant_delay) | |||||
| s = 'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'.format( | |||||
| self.min_init, self.max_init, self.ema, self.ema_decay, | |||||
| self.per_channel, self.quant_delay, self.channel_axis) | |||||
| return s | return s | ||||
| def construct(self, x, minq, maxq): | |||||
| if self.training: | |||||
| min_up, max_up = self.ema_update(x, minq, maxq) | |||||
| def construct(self, x): | |||||
| if self.update: | |||||
| min_up, max_up = self.ema_update(x, self.minq, self.maxq) | |||||
| out = self.fake_quant(x, min_up, max_up) | out = self.fake_quant(x, min_up, max_up) | ||||
| P.Assign()(self.minq, min_up) | P.Assign()(self.minq, min_up) | ||||
| P.Assign()(self.maxq, max_up) | P.Assign()(self.maxq, max_up) | ||||
| else: | else: | ||||
| out = self.fake_quant(x, minq, maxq) | |||||
| out = self.fake_quant(x, self.minq, self.maxq) | |||||
| return out | return out | ||||
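# Illustrative sketch only (not part of the patch): a NumPy model of what the
# training branch of construct() above computes for an 8-bit, per-layer,
# asymmetric quantizer. The helper names np_ema_update / np_fake_quant are
# hypothetical; rounding details of the TBE kernels are approximated.
import numpy as np

def np_ema_update(x, min_val, max_val, ema_decay=0.999):
    # EMA-style running min/max update, clamped so that min <= 0 <= max
    new_min = ema_decay * min_val + (1 - ema_decay) * x.min()
    new_max = ema_decay * max_val + (1 - ema_decay) * x.max()
    return min(new_min, 0.0), max(new_max, 0.0)

def np_fake_quant(x, min_val, max_val, num_bits=8):
    # quantize-dequantize x onto a num_bits grid spanning [min_val, max_val]
    quant_min, quant_max = 0, 2 ** num_bits - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    zp = np.clip(np.round(quant_min - min_val / scale), quant_min, quant_max)
    nudged_min, nudged_max = (quant_min - zp) * scale, (quant_max - zp) * scale
    clipped = np.clip(x, nudged_min, nudged_max)
    return np.round((clipped - nudged_min) / scale) * scale + nudged_min

x = np.array([[1.0, 2.0, 1.0], [-2.0, 0.0, -1.0]], np.float32)
mn, mx = np_ema_update(x, -6.0, 6.0)
print(np_fake_quant(x, mn, mx))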
-class FakeQuantWithMinMax(Cell):
+class FakeQuantWithMinMaxGPU(Cell):
     r"""
     Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
@@ -240,98 +242,69 @@ class FakeQuantWithMinMax(Cell):
                  ema=False,
                  ema_decay=0.999,
                  per_channel=False,
+                 channel_axis=1,
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False):
-        """init FakeQuantWithMinMax layer"""
-        super(FakeQuantWithMinMax, self).__init__()
+                 narrow_range=False,
+                 training=True):
+        super(FakeQuantWithMinMaxGPU, self).__init__()
         self.min_init = min_init
-        self.num_bits = num_bits
         self.max_init = max_init
+        self.num_bits = num_bits
         self.ema = ema
         self.ema_decay = ema_decay
         self.per_channel = per_channel
         self.out_channels = out_channels
+        self.channel_axis = channel_axis
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
+        self.training = training
-        if per_channel:
-            min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
-            max_array = np.array([self.max_init for i in range(0, self.channel_size)]).astype(np.float32)
-            self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
-            self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
-            self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
-                                                                    ema=self.ema,
-                                                                    ema_decay=self.ema_decay,
-                                                                    quant_delay=self.quant_delay,
-                                                                    symmetric=self.symmetric,
-                                                                    narrow_range=self.narrow_range,
-                                                                    training=True)
-            self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits,
-                                                                    ema=self.ema,
-                                                                    ema_decay=self.ema_decay,
-                                                                    quant_delay=self.quant_delay,
-                                                                    symmetric=self.symmetric,
-                                                                    narrow_range=self.narrow_range,
-                                                                    training=False)
-        else:
         # init tensor min and max for fake quant op
         if isinstance(min_init, int):
             min_array = np.array([min_init]).reshape(1).astype(np.float32)
             max_array = np.array([max_init]).reshape(1).astype(np.float32)
-            self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
-            self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
-            if context.get_context('device_target') == "Ascend":
-                self.fake_quant_train = FakeQuantWithMinMaxD(num_bits=self.num_bits,
-                                                             ema=self.ema,
-                                                             ema_decay=self.ema_decay,
-                                                             quant_delay=self.quant_delay,
-                                                             symmetric=self.symmetric,
-                                                             narrow_range=self.narrow_range,
-                                                             training=True,
-                                                             min_init=self.minq,
-                                                             max_init=self.maxq)
-                self.fake_quant_infer = FakeQuantWithMinMaxD(num_bits=self.num_bits,
-                                                             ema=self.ema,
-                                                             ema_decay=self.ema_decay,
-                                                             quant_delay=self.quant_delay,
-                                                             symmetric=self.symmetric,
-                                                             narrow_range=self.narrow_range,
-                                                             training=False,
-                                                             min_init=self.minq,
-                                                             max_init=self.maxq)
-            elif context.get_context('device_target') == "GPU":
-                self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits,
-                                                              ema=self.ema,
-                                                              ema_decay=self.ema_decay,
-                                                              quant_delay=self.quant_delay,
-                                                              symmetric=self.symmetric,
-                                                              narrow_range=self.narrow_range,
-                                                              training=True)
-                self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits,
-                                                              ema=self.ema,
-                                                              ema_decay=ema_decay,
-                                                              quant_delay=quant_delay,
-                                                              symmetric=self.symmetric,
-                                                              narrow_range=self.narrow_range,
-                                                              training=False)
-            else:
-                raise ValueError("Not support platform.")
+        elif isinstance(min_init, list):
+            min_array = np.array([self.min_init for i in range(
+                0, self.out_channels)]).astype(np.float32)
+            max_array = np.array([self.max_init for i in range(
+                0, self.out_channels)]).astype(np.float32)
+        self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
+        self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
+        if per_channel:
+            quant_fun = partial(P.FakeQuantPerChannel, channel_axis=self.channel_axis)
+        else:
+            quant_fun = P.FakeQuantPerLayer
+        self.fake_quant = quant_fun(num_bits=self.num_bits,
+                                    ema=self.ema,
+                                    ema_decay=ema_decay,
+                                    quant_delay=quant_delay,
+                                    symmetric=self.symmetric,
+                                    narrow_range=self.narrow_range,
+                                    training=self.training)
     def extend_repr(self):
-        s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'.format(
-            self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.quant_delay)
+        s = 'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}'.format(
+            self.min_init, self.max_init, self.ema, self.ema_decay,
+            self.per_channel, self.quant_delay, self.channel_axis)
         return s
     def construct(self, x):
-        if self.training:
-            out = self.fake_quant_train(x, self.minq, self.maxq)
-        else:
-            out = self.fake_quant_infer(x, self.minq, self.maxq)
+        out = self.fake_quant(x, self.minq, self.maxq)
         return out
+def FakeQuantWithMinMax(**kwargs):
+    if context.get_context('device_target') == "Ascend":
+        out = FakeQuantWithMinMaxAscend(**kwargs)
+    elif context.get_context('device_target') == "GPU":
+        out = FakeQuantWithMinMaxGPU(**kwargs)
+    else:
+        raise ValueError("Not support platform or channel mode.")
+    return out
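# Illustrative usage only (not part of the patch): with the device-dispatching
# wrapper above, callers keep constructing FakeQuantWithMinMax by name and get
# the Ascend or GPU cell depending on the current context. The argument values
# below are examples, not required settings.
import numpy as np
import mindspore
from mindspore import Tensor, context

context.set_context(device_target="GPU")
fake_quant = FakeQuantWithMinMax(min_init=-6, max_init=6, num_bits=8, ema=True)
x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
y = fake_quant(x)   # same shape/dtype as x, values snapped to the 8-bit grid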
 class Conv2dBatchNormQuant(Cell):
     r"""
     2D convolution with BatchNormal op folded layer.
@@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell):
         self.per_channel = per_channel
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.channel_axis = int(group > 1)
         self.is_gpu = context.get_context('device_target') == "GPU"
         # initialize convolution op and Parameter
@@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell):
                                        dilation=self.dilation)
             if weight_init is None:
                 weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
+            channel_axis = 1
         else:
             self.conv = P.Conv2D(out_channel=out_channels,
                                  kernel_size=self.kernel_size,
@@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell):
                                  group=group)
             if weight_init is None:
                 weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
+            channel_axis = 0
         self.weight = Parameter(weight_init, name='weight')
         # initialize batchnorm Parameter
@@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell):
                                              symmetric=symmetric,
                                              narrow_range=narrow_range)
         self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
-        self.correct_mul = P.CorrectionMul(self.channel_axis)
+        self.correct_mul = P.CorrectionMul(channel_axis)
         if context.get_context('device_target') == "Ascend":
             self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
             self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
@@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell):
                 out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
                 F.control_depend(out, self.assignadd(self.step, self.one))
             else:
-                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
+                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
         return out
| @@ -20,10 +20,11 @@ from .grad_base import bprop_getters | |||||
| from ..composite.multitype_ops.zeros_like_impl import zeros_like | from ..composite.multitype_ops.zeros_like_impl import zeros_like | ||||
| @bprop_getters.register(P.FakeQuantWithMinMax) | |||||
| @bprop_getters.register(P.FakeQuantPerLayer) | |||||
| def get_bprop_fakequant_with_minmax(self): | def get_bprop_fakequant_with_minmax(self): | ||||
| """Generate bprop for FakeQuantWithMinMax for GPU and Ascend""" | |||||
| op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) | |||||
| """Generate bprop for FakeQuantPerLayer for GPU and Ascend""" | |||||
| op = P.FakeQuantPerLayerGrad( | |||||
| num_bits=self.num_bits, quant_delay=self.quant_delay) | |||||
| def bprop(x, x_min, x_max, out, dout): | def bprop(x, x_min, x_max, out, dout): | ||||
| dx = op(dout, x, x_min, x_max) | dx = op(dout, x, x_min, x_max) | ||||
| @@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self): | |||||
| return bprop | return bprop | ||||
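# Not part of the patch: a NumPy sketch of the gradient that
# FakeQuantPerLayerGrad computes, i.e. a straight-through estimator that passes
# dout where x lies inside the quantization range and zeroes it elsewhere
# (min and max themselves typically receive zero gradients, as with the
# zeros_like returns used for the update ops below). The nudging of min/max is
# ignored here for brevity.
import numpy as np

def np_fake_quant_grad(dout, x, x_min, x_max):
    inside = (x >= x_min) & (x <= x_max)
    return dout * inside.astype(dout.dtype)

print(np_fake_quant_grad(np.ones((2, 3), np.float32),
                         np.array([[1., 2., 1.], [-2., 0., -7.]], np.float32),
                         -6.0, 6.0))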
-@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
+@bprop_getters.register(P.FakeQuantPerChannel)
 def get_bprop_fakequant_with_minmax_perchannel(self):
-    """Generate bprop for FakeQuantWithMinMaxPerChannel for GPU"""
-    op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)
+    """Generate bprop for FakeQuantPerChannel"""
+    op = P.FakeQuantPerChannelGrad(num_bits=self.num_bits,
+                                   quant_delay=self.quant_delay,
+                                   symmetric=self.symmetric,
+                                   narrow_range=self.narrow_range,
+                                   channel_axis=self.channel_axis)
     def bprop(x, x_min, x_max, out, dout):
         dx = op(dout, x, x_min, x_max)
@@ -77,7 +82,7 @@ def get_bprop_batchnorm_fold2(self):
         d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std,
                                                                running_mean, global_step)
         return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \
-               zeros_like(global_step)
+            zeros_like(global_step)
     return bprop
@@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self):
     return bprop
-@bprop_getters.register(P.FakeQuantWithMinMaxUpdate)
-def get_bprop_fakequant_with_minmax_update(self):
-    """Generate bprop for FakeQuantWithMinMaxUpdate for Ascend"""
+@bprop_getters.register(P.FakeQuantMinMaxPerLayerUpdate)
+def get_bprop_fakequant_with_minmax_per_layer_update(self):
+    """Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
+    def bprop(x, x_min, x_max, out, dout):
+        return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
+    return bprop
+@bprop_getters.register(P.FakeQuantMinMaxPerChannelUpdate)
+def get_bprop_fakequant_with_minmax_per_channel_update(self):
+    """Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
     def bprop(x, x_min, x_max, out, dout):
         return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
| @@ -0,0 +1,135 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """FakeQuantMinMaxPerChannelUpdate op""" | |||||
| import te.lang.cce | |||||
| from te import tvm | |||||
| from te.platform.fusion_manager import fusion_manager | |||||
| from topi import generic | |||||
| from topi.cce import util | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("fake_quant_min_max_per_channel_update.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("fake_quant_min_max_per_channel_update") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("ema", "optional", "bool", "all") \ | |||||
| .attr("ema_decay", "optional", "float", "all") \ | |||||
| .attr("symmetric", "optional", "bool", "all") \ | |||||
| .attr("narrow_range", "optional", "bool", "all") \ | |||||
| .attr("training", "optional", "bool", "all") \ | |||||
| .attr("num_bits", "optional", "int", "all") \ | |||||
| .attr("channel_axis", "optional", "int", "all") \ | |||||
| .input(0, "x", None, "required", None) \ | |||||
| .input(1, "min", None, "required", None) \ | |||||
| .input(2, "max", None, "required", None) \ | |||||
| .output(0, "min_up", True, "required", "all") \ | |||||
| .output(1, "max_up", True, "required", "all") \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD) \ | |||||
| .get_op_info() | |||||
| @op_info_register(fake_quant_min_max_per_channel_update_op_info) | |||||
| def _fake_quant_min_max_per_channel_update_tbe(): | |||||
| """FakeQuantPerChannelUpdate TBE register""" | |||||
| return | |||||
| @fusion_manager.register("fake_quant_min_max_per_channel_update") | |||||
| def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val, | |||||
| ema, ema_decay, quant_min, quant_max, training, channel_axis, | |||||
| kernel_name="fake_quant_min_max_per_channel_update"): | |||||
| """FakeQuantPerChannelUpdate compute""" | |||||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | |||||
| if not ema: | |||||
| ema_decay = 0.0 | |||||
| if training: | |||||
| # CalMinMax | |||||
| axis = [0, 2, 3] | |||||
| x_min = te.lang.cce.reduce_min(x, axis=axis) | |||||
| x_max = te.lang.cce.reduce_max(x, axis=axis) | |||||
| x_min = te.lang.cce.broadcast(x_min, shape_min) | |||||
| x_max = te.lang.cce.broadcast(x_max, shape_min) | |||||
| min_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||||
| min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) | |||||
| max_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||||
| max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) | |||||
| min_val = te.lang.cce.vmins(min_val, 0) | |||||
| max_val = te.lang.cce.vmaxs(max_val, 0) | |||||
| return [min_val, max_val] | |||||
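# Not part of the patch: a NumPy sketch of the per-channel min/max EMA update
# above for NCHW data with channel_axis=1 (hence the reduction over axes
# (0, 2, 3)); the result is clamped per channel so that min <= 0 <= max.
import numpy as np

def np_minmax_per_channel_update(x, min_val, max_val, ema=True, ema_decay=0.999):
    decay = ema_decay if ema else 0.0
    x_min = x.min(axis=(0, 2, 3))          # one running min per channel
    x_max = x.max(axis=(0, 2, 3))          # one running max per channel
    min_val = decay * min_val + (1 - decay) * x_min
    max_val = decay * max_val + (1 - decay) * x_max
    return np.minimum(min_val, 0.0), np.maximum(max_val, 0.0)

x = np.random.randn(4, 3, 8, 8).astype(np.float32)
mn, mx = np_minmax_per_channel_update(x, np.full(3, -6.0), np.full(3, 6.0))
print(mn.shape, mx.shape)   # (3,) (3,)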
+@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
+def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
+                                          ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis,
+                                          kernel_name="fake_quant_min_max_per_channel_update"):
+    """FakeQuantMinMaxPerChannelUpdate op"""
+    x_shape = x.get("ori_shape")
+    x_format = x.get("format")
+    x_dtype = x.get("dtype")
+    min_shape = min_val.get("ori_shape")
+    min_dtype = min_val.get("dtype")
+    max_shape = max_val.get("ori_shape")
+    max_dtype = max_val.get("dtype")
+    util.check_kernel_name(kernel_name)
+    util.check_shape_rule(x_shape)
+    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_tensor_shape_size(x_shape)
+    util.check_tensor_shape_size(min_shape)
+    util.check_tensor_shape_size(max_shape)
+    check_list = ["float32", "float16"]
+    x_dtype = x_dtype.lower()
+    min_dtype = min_dtype.lower()
+    max_dtype = max_dtype.lower()
+    util.check_dtype_rule(x_dtype, check_list)
+    util.check_dtype_rule(min_dtype, check_list)
+    util.check_dtype_rule(max_dtype, check_list)
+    if symmetric:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    if narrow_range:
+        quant_min = quant_min + 1
+    shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
+    input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
+    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
+    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
+    res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
+                                                             ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
+    with tvm.target.cce():
+        sch = generic.auto_schedule(res_list)
+    tensor_list = [input_data, min_data, max_data] + list(res_list)
+    config = {"print_ir": False,
+              "name": kernel_name,
+              "tensor_list": tensor_list}
+    te.lang.cce.cce_build_code(sch, config)
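# Not part of the patch: the quant_min/quant_max arithmetic used above, worked
# out for num_bits = 8.
#   asymmetric:  quant_min = 0,    quant_max = 255
#   symmetric:   quant_min = -128, quant_max = 127
#   narrow_range then raises quant_min by one (giving 1 or -127).
for symmetric in (False, True):
    for narrow_range in (False, True):
        num_bits = 8
        if symmetric:
            quant_min, quant_max = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
        else:
            quant_min, quant_max = 0, 2 ** num_bits - 1
        if narrow_range:
            quant_min += 1
        print(symmetric, narrow_range, quant_min, quant_max)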
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """FakeQuantWithMinMaxUpdate op""" | |||||
| """FakeQuantMinMaxPerLayerUpdate op""" | |||||
| from functools import reduce as functools_reduce | from functools import reduce as functools_reduce | ||||
| import te.lang.cce | import te.lang.cce | ||||
| from te import tvm | from te import tvm | ||||
| @@ -23,12 +23,12 @@ from topi.cce import util | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | ||||
| fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ | |||||
| fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .async_flag(False) \ | .async_flag(False) \ | ||||
| .binfile_name("fake_quant_with_min_max_update.so") \ | |||||
| .binfile_name("fake_quant_minmax_update.so") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("fake_quant_with_min_max_update") \ | |||||
| .kernel_name("fake_quant_minmax_update") \ | |||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .attr("ema", "optional", "bool", "all") \ | .attr("ema", "optional", "bool", "all") \ | ||||
| .attr("ema_decay", "optional", "float", "all") \ | .attr("ema_decay", "optional", "float", "all") \ | ||||
| @@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ | |||||
| .attr("narrow_range", "optional", "bool", "all") \ | .attr("narrow_range", "optional", "bool", "all") \ | ||||
| .attr("training", "optional", "bool", "all") \ | .attr("training", "optional", "bool", "all") \ | ||||
| .attr("num_bits", "optional", "int", "all") \ | .attr("num_bits", "optional", "int", "all") \ | ||||
| .attr("quant_delay", "optional", "int", "all") \ | |||||
| .input(0, "x", None, "required", None) \ | .input(0, "x", None, "required", None) \ | ||||
| .input(1, "min", None, "required", None) \ | .input(1, "min", None, "required", None) \ | ||||
| .input(2, "max", None, "required", None) \ | .input(2, "max", None, "required", None) \ | ||||
| @@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @op_info_register(fake_quant_update_op_info) | |||||
| def _fake_quant_update_tbe(): | |||||
| """_FakeQuantWithMinMaxUpdate TBE register""" | |||||
| @op_info_register(fake_quant_minmax_update_op_info) | |||||
| def _fake_quant_minmax_update_tbe(): | |||||
| """FakeQuantMinMaxPerLayerUpdate TBE register""" | |||||
| return | return | ||||
| @fusion_manager.register("fake_quant_with_min_max_update") | |||||
| def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, | |||||
| kernel_name="fake_quant_update"): | |||||
| """FakeQuantWithMinMaxUpdate compute""" | |||||
| @fusion_manager.register("fake_quant_minmax_update") | |||||
| def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, | |||||
| kernel_name="fake_quant_minmax_update"): | |||||
| """FakeQuantMinMaxPerLayerUpdate compute""" | |||||
| shape = te.lang.cce.util.shape_to_list(x.shape) | shape = te.lang.cce.util.shape_to_list(x.shape) | ||||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | ||||
| min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) | min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) | ||||
| @@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, | |||||
| x_max = te.lang.cce.reduce_max(x, axis=axis) | x_max = te.lang.cce.reduce_max(x, axis=axis) | ||||
| x_min = te.lang.cce.broadcast(x_min, shape_min) | x_min = te.lang.cce.broadcast(x_min, shape_min) | ||||
| x_max = te.lang.cce.broadcast(x_max, shape_min) | x_max = te.lang.cce.broadcast(x_max, shape_min) | ||||
| min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) | |||||
| max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) | |||||
| min_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||||
| min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) | |||||
| max_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||||
| max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) | |||||
| min_val = te.lang.cce.vmins(min_val, 0) | min_val = te.lang.cce.vmins(min_val, 0) | ||||
| max_val = te.lang.cce.vmaxs(max_val, 0) | max_val = te.lang.cce.vmaxs(max_val, 0) | ||||
| return [min_val, max_val] | return [min_val, max_val] | ||||
| @util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) | |||||
| def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up, | |||||
| ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay, | |||||
| kernel_name="fake_quant_update"): | |||||
| """FakeQuantWithMinMax op""" | |||||
| @util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str) | |||||
| def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, | |||||
| ema, ema_decay, symmetric, narrow_range, training, num_bits, | |||||
| kernel_name="fake_quant_minmax_update"): | |||||
| """FakeQuantPerLayer op""" | |||||
| input_shape = x.get("shape") | input_shape = x.get("shape") | ||||
| input_dtype = x.get("dtype") | input_dtype = x.get("dtype") | ||||
| min_shape = min_val.get("ori_shape") | min_shape = min_val.get("ori_shape") | ||||
| @@ -123,8 +124,8 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up, | |||||
| input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | ||||
| min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | ||||
| max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | ||||
| res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data, | |||||
| ema, ema_decay, quant_min, quant_max, training, kernel_name) | |||||
| res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data, | |||||
| ema, ema_decay, quant_min, quant_max, training, kernel_name) | |||||
| with tvm.target.cce(): | with tvm.target.cce(): | ||||
| sch = generic.auto_schedule(res_list) | sch = generic.auto_schedule(res_list) | ||||
| @@ -0,0 +1,145 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """FakeQuantPerChannel op""" | |||||
| import te.lang.cce | |||||
| from te import tvm | |||||
| from te.platform.fusion_manager import fusion_manager | |||||
| from topi import generic | |||||
| from topi.cce import util | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| fake_quant_perchannel_op_info = TBERegOp("FakeQuantPerChannel") \ | |||||
| .fusion_type("ELEMWISE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("fake_quant_perchannel.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("fake_quant_perchannel") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("symmetric", "optional", "bool", "all") \ | |||||
| .attr("narrow_range", "optional", "bool", "all") \ | |||||
| .attr("num_bits", "optional", "int", "all") \ | |||||
| .attr("channel_axis", "optional", "int", "all") \ | |||||
| .input(0, "x", None, "required", None) \ | |||||
| .input(1, "min", None, "required", None) \ | |||||
| .input(2, "max", None, "required", None) \ | |||||
| .output(0, "y", True, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .get_op_info() | |||||
| @op_info_register(fake_quant_perchannel_op_info) | |||||
| def _fake_quant_perchannel_tbe(): | |||||
| """FakeQuantPerChannel TBE register""" | |||||
| return | |||||
| @fusion_manager.register("fake_quant_perchannel") | |||||
| def fake_quant_perchannel_compute(x, min_val, max_val, y, quant_min, quant_max, | |||||
| kernel_name="fake_quant_perchannel"): | |||||
| """FakeQuantPerChannel""" | |||||
| x_shape = te.lang.cce.util.shape_to_list(x.shape) | |||||
| minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape) | |||||
| quant_min = tvm.const(quant_min, x.dtype) | |||||
| quant_max = tvm.const(quant_max, x.dtype) | |||||
| quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype) | |||||
| quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype) | |||||
| # CalNudge(NudgeMinMax) | |||||
| scale = te.lang.cce.vdiv(te.lang.cce.vsub( | |||||
| max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) | |||||
| zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) | |||||
| # Nudge zero point | |||||
| nudge_zp_ = te.lang.cce.vmin( | |||||
| quant_max, te.lang.cce.vmax(quant_min, zp_from_min)) | |||||
| nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5)) | |||||
| nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) | |||||
| nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) | |||||
| # FakeQuant | |||||
| nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape) | |||||
| nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape) | |||||
| scale_b = te.lang.cce.broadcast(scale, x_shape) | |||||
| input_x = te.lang.cce.vmin(nudge_max_b, te.lang.cce.vmax(nudge_min_b, x)) | |||||
| nudge_input_ = te.lang.cce.vdiv( | |||||
| te.lang.cce.vsub(input_x, nudge_min_b), scale_b) | |||||
| nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5)) | |||||
| res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale_b), nudge_min_b) | |||||
| return res | |||||
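# Not part of the patch: a NumPy sketch of the per-channel fake quantization
# above. min/max/scale carry one entry per channel and are broadcast against an
# NCHW tensor (channel_axis=1 assumed here); rounding mimics the
# floor(x + 0.5) used in the kernel.
import numpy as np

def np_fake_quant_per_channel(x, min_val, max_val, num_bits=8, channel_axis=1):
    quant_min, quant_max = 0.0, 2.0 ** num_bits - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    zp = np.floor(np.clip(quant_min - min_val / scale, quant_min, quant_max) + 0.5)
    nudge_min = (quant_min - zp) * scale
    nudge_max = (quant_max - zp) * scale
    # reshape the per-channel vectors so they broadcast along channel_axis
    shape = [1] * x.ndim
    shape[channel_axis] = -1
    scale, nudge_min, nudge_max = (v.reshape(shape) for v in (scale, nudge_min, nudge_max))
    clipped = np.clip(x, nudge_min, nudge_max)
    return np.floor((clipped - nudge_min) / scale + 0.5) * scale + nudge_min

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
print(np_fake_quant_per_channel(x, np.array([-6., -4., -2.]), np.array([6., 4., 2.])).shape)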
+@util.check_input_type(dict, dict, dict, dict, bool, bool, int, int, str)
+def fake_quant_perchannel(x, min_val, max_val, y,
+                          symmetric, narrow_range, num_bits, channel_axis,
+                          kernel_name="fake_quant_perchannel"):
+    """FakeQuantPerChannel"""
+    x_shape = x.get("shape")
+    x_format = x.get("format")
+    x_dtype = x.get("dtype")
+    min_shape = min_val.get("ori_shape")
+    min_dtype = min_val.get("dtype")
+    max_shape = max_val.get("ori_shape")
+    max_dtype = max_val.get("dtype")
+    util.check_kernel_name(kernel_name)
+    util.check_shape_rule(x_shape)
+    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_tensor_shape_size(x_shape)
+    util.check_tensor_shape_size(min_shape)
+    util.check_tensor_shape_size(max_shape)
+    check_list = ["float32", "float16"]
+    x_dtype = x_dtype.lower()
+    min_dtype = min_dtype.lower()
+    max_dtype = max_dtype.lower()
+    util.check_dtype_rule(x_dtype, check_list)
+    util.check_dtype_rule(min_dtype, check_list)
+    util.check_dtype_rule(max_dtype, check_list)
+    if symmetric:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    if narrow_range:
+        quant_min = quant_min + 1
+    shape_c = [1] * len(x_shape)
+    shape_c[channel_axis] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis == 1:
+        shape_c = min_val.get("shape")
+    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
+    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
+    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
+    res = fake_quant_perchannel_compute(input_data, min_data, max_data, y,
+                                        quant_min, quant_max, kernel_name)
+    with tvm.target.cce():
+        sch = generic.auto_schedule(res)
+    tensor_list = [input_data, min_data, max_data, res]
+    config = {"print_ir": False,
+              "name": kernel_name,
+              "tensor_list": tensor_list}
+    te.lang.cce.cce_build_code(sch, config)
| @@ -0,0 +1,171 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """FakeQuantPerChannelGrad op""" | |||||
| import te.lang.cce | |||||
| from te import tvm | |||||
| from te.platform.fusion_manager import fusion_manager | |||||
| from topi import generic | |||||
| from topi.cce import util | |||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| SHAPE_SIZE_LIMIT = 2147483648 | |||||
| D_TYPE = 'float32' | |||||
| fake_quant_perchannel_grad_op_info = TBERegOp("FakeQuantPerChannelGrad") \ | |||||
| .fusion_type("OPAQUE") \ | |||||
| .async_flag(False) \ | |||||
| .binfile_name("fake_quant_perchannel_grad.so") \ | |||||
| .compute_cost(10) \ | |||||
| .kernel_name("fake_quant_perchannel_grad") \ | |||||
| .partial_flag(True) \ | |||||
| .attr("symmetric", "optional", "bool", "all") \ | |||||
| .attr("narrow_range", "optional", "bool", "all") \ | |||||
| .attr("num_bits", "optional", "int", "all") \ | |||||
| .attr("channel_axis", "optional", "int", "all") \ | |||||
| .input(0, "dout", None, "required", None) \ | |||||
| .input(1, "x", None, "required", None) \ | |||||
| .input(2, "min", None, "required", None) \ | |||||
| .input(3, "max", None, "required", None) \ | |||||
| .output(0, "dx", True, "required", "all") \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .get_op_info() | |||||
| def _less_compare_float32(data_x, data_y): | |||||
| """_less_compare_float32 compute""" | |||||
| input_shape = te.lang.cce.util.shape_to_list(data_x.shape) | |||||
| min_value = tvm.const(2 ** (-126), dtype=D_TYPE) | |||||
| max_value = tvm.const(2 ** 62, dtype=D_TYPE) | |||||
| factor_value = tvm.const(2 ** 2, dtype=D_TYPE) | |||||
| data_zero = te.lang.cce.broadcast( | |||||
| tvm.const(0, dtype=D_TYPE), input_shape, D_TYPE) | |||||
| min_value_tensor = te.lang.cce.vadds(data_zero, min_value) | |||||
| res_sub = te.lang.cce.vsub(data_y, data_x) | |||||
| res_min = te.lang.cce.vmin(res_sub, min_value_tensor) | |||||
| res_max = te.lang.cce.vmax(res_min, data_zero) | |||||
| res_max_mul = te.lang.cce.vmuls(res_max, max_value) | |||||
| res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value) | |||||
| res = te.lang.cce.vmuls(res_max_mul_max, factor_value) | |||||
| return res | |||||
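# Not part of the patch: what _less_compare_float32 computes, in NumPy terms.
# It builds a float 0/1 mask for "data_x < data_y" without a compare op:
# clamp (y - x) into [0, 2**-126], then scale by 2**62 * 2**62 * 2**2 = 2**126,
# so any (non-subnormal) positive difference becomes exactly 1.0 and a
# non-positive one stays 0.0.
import numpy as np

def np_less_mask(x, y):
    diff = np.clip(y - x, 0.0, 2.0 ** -126).astype(np.float32)
    return diff * 2.0 ** 62 * 2.0 ** 62 * 4.0

print(np_less_mask(np.array([1.0, 3.0], np.float32),
                   np.array([2.0, 2.0], np.float32)))  # [1. 0.]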
+@op_info_register(fake_quant_perchannel_grad_op_info)
+def _fake_quant_perchannel_grad_tbe():
+    """FakeQuantPerChannelGrad TBE register"""
+    return
+@fusion_manager.register("fake_quant_perchannel_grad")
+def fake_quant_perchannel_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
+                                       kernel_name="fake_quant_perchannel_grad"):
+    """FakeQuantPerChannelGrad"""
+    x_shape = te.lang.cce.util.shape_to_list(x.shape)
+    minmax_shape = te.lang.cce.util.shape_to_list(min_val.shape)
+    quant_min = tvm.const(quant_min, x.dtype)
+    quant_max = tvm.const(quant_max, x.dtype)
+    quant_min = te.lang.cce.broadcast(quant_min, minmax_shape, x.dtype)
+    quant_max = te.lang.cce.broadcast(quant_max, minmax_shape, x.dtype)
+    # CalNudge(NudgeMinMax)
+    scale = te.lang.cce.vdiv(te.lang.cce.vsub(
+        max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
+    zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
+    # Nudge zero point
+    nudge_zp_ = te.lang.cce.vmin(
+        quant_max, te.lang.cce.vmax(quant_min, zp_from_min))
+    nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5))
+    nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
+    nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
+    # FakeQuant Grad
+    nudge_min_b = te.lang.cce.broadcast(nudge_min, x_shape)
+    nudge_max_b = te.lang.cce.broadcast(nudge_max, x_shape)
+    bool_over_min = _less_compare_float32(nudge_min_b, x)
+    bool_less_max = _less_compare_float32(x, nudge_max_b)
+    bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
+    res = te.lang.cce.vmul(dout, bool_between)
+    return res
+@util.check_input_type(dict, dict, dict, dict, dict, bool, bool, int, int, str)
+def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
+                               symmetric, narrow_range, num_bits, channel_axis,
+                               kernel_name="fake_quant_perchannel_grad"):
+    """FakeQuantPerChannelGrad"""
+    x_shape = x.get("shape")
+    x_format = x.get("format")
+    x_dtype = x.get("dtype")
+    min_shape = min_val.get("ori_shape")
+    min_dtype = min_val.get("dtype")
+    max_shape = max_val.get("ori_shape")
+    max_dtype = max_val.get("dtype")
+    util.check_kernel_name(kernel_name)
+    util.check_shape_rule(x_shape)
+    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_tensor_shape_size(x_shape)
+    util.check_tensor_shape_size(min_shape)
+    util.check_tensor_shape_size(max_shape)
+    check_list = ["float32", "float16"]
+    x_dtype = x_dtype.lower()
+    min_dtype = min_dtype.lower()
+    max_dtype = max_dtype.lower()
+    util.check_dtype_rule(x_dtype, check_list)
+    util.check_dtype_rule(min_dtype, check_list)
+    util.check_dtype_rule(max_dtype, check_list)
+    if symmetric:
+        quant_min = 0 - 2 ** (num_bits - 1)
+        quant_max = 2 ** (num_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2 ** num_bits - 1
+    if narrow_range:
+        quant_min = quant_min + 1
+    shape_c = [1] * len(x_shape)
+    shape_c[channel_axis] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis == 1:
+        shape_c = min_val.get("shape")
+    dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
+    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
+    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
+    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
+    res = fake_quant_perchannel_grad_compute(dout_data, input_data, min_data, max_data,
+                                             quant_min, quant_max, kernel_name)
+    with tvm.target.cce():
+        sch = generic.auto_schedule(res)
+    tensor_list = [dout_data, input_data, min_data, max_data, res]
+    config = {"print_ir": False,
+              "name": kernel_name,
+              "tensor_list": tensor_list}
+    te.lang.cce.cce_build_code(sch, config)
| @@ -13,8 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """FakeQuantWithMinMax op""" | |||||
| """FakeQuantPerLayer op""" | |||||
| from functools import reduce as functools_reduce | from functools import reduce as functools_reduce | ||||
| import te.lang.cce | import te.lang.cce | ||||
| from te import tvm | from te import tvm | ||||
| @@ -23,20 +22,16 @@ from topi import generic | |||||
| from topi.cce import util | from topi.cce import util | ||||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | ||||
| fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \ | |||||
| fake_quant_per_layer_op_info = TBERegOp("FakeQuantPerLayer") \ | |||||
| .fusion_type("ELEMWISE") \ | .fusion_type("ELEMWISE") \ | ||||
| .async_flag(False) \ | .async_flag(False) \ | ||||
| .binfile_name("fake_quant_with_min_max_vars_ema.so") \ | |||||
| .binfile_name("fake_quant_per_layer.so") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("fake_quant_with_min_max_vars_ema") \ | |||||
| .kernel_name("fake_quant_per_layer") \ | |||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .attr("ema", "optional", "bool", "all") \ | |||||
| .attr("ema_decay", "optional", "float", "all") \ | |||||
| .attr("symmetric", "optional", "bool", "all") \ | .attr("symmetric", "optional", "bool", "all") \ | ||||
| .attr("narrow_range", "optional", "bool", "all") \ | .attr("narrow_range", "optional", "bool", "all") \ | ||||
| .attr("training", "optional", "bool", "all") \ | |||||
| .attr("num_bits", "optional", "int", "all") \ | .attr("num_bits", "optional", "int", "all") \ | ||||
| .attr("quant_delay", "optional", "int", "all") \ | |||||
| .input(0, "x", None, "required", None) \ | .input(0, "x", None, "required", None) \ | ||||
| .input(1, "min", None, "required", None) \ | .input(1, "min", None, "required", None) \ | ||||
| .input(2, "max", None, "required", None) \ | .input(2, "max", None, "required", None) \ | ||||
| @@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \ | |||||
| @op_info_register(fake_quant_op_info) | @op_info_register(fake_quant_op_info) | ||||
| def _fake_quant_tbe(): | |||||
| """FakeQuantWithMinMax TBE register""" | |||||
| def _fake_quant_per_layer_tbe(): | |||||
| """FakeQuantPerLayer TBE register""" | |||||
| return | return | ||||
| @fusion_manager.register("fake_quant_with_min_max_vars_ema") | |||||
| def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max, | |||||
| kernel_name="correction_mul"): | |||||
| """FakeQuantWithMinMax""" | |||||
| @fusion_manager.register("fake_quant_per_layer") | |||||
| def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, | |||||
| kernel_name="fake_quant_per_layer"): | |||||
| """FakeQuantPerLayer""" | |||||
| shape = te.lang.cce.util.shape_to_list(x.shape) | shape = te.lang.cce.util.shape_to_list(x.shape) | ||||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | ||||
| quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype) | quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype) | ||||
| @@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, | |||||
| max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) | max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) | ||||
| # CalNudge(NudgeMinMax) | # CalNudge(NudgeMinMax) | ||||
| scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) | |||||
| scale = te.lang.cce.vdiv(te.lang.cce.vsub( | |||||
| max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) | |||||
| zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) | zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) | ||||
| # Nudge zero point | # Nudge zero point | ||||
| nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min))) | |||||
| nudge_zp_ = te.lang.cce.vmin( | |||||
| quant_max, te.lang.cce.vmax(quant_min, zp_from_min)) | |||||
| nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5)) | |||||
| nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) | nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) | ||||
| nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) | nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) | ||||
| @@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, | |||||
| # FakeQuant | # FakeQuant | ||||
| input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x)) | input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x)) | ||||
| nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale)) | |||||
| nudge_input_ = te.lang.cce.vdiv( | |||||
| te.lang.cce.vsub(input_x, nudge_min), scale) | |||||
| nudge_input = te.lang.cce.floor(te.lang.cce.vadds(nudge_input_, 0.5)) | |||||
| res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min) | res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min) | ||||
| return res | return res | ||||
| @util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) | |||||
| def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y, | |||||
| ema, ema_decay, symmetric, narrow_range, training, num_bits, quant_delay, | |||||
| kernel_name="fake_quant"): | |||||
| """FakeQuantWithMinMax""" | |||||
| @util.check_input_type(dict, dict, dict, dict, bool, bool, int, str) | |||||
| def fake_quant_per_layer(x, min_val, max_val, y, | |||||
| symmetric, narrow_range, num_bits, | |||||
| kernel_name="fake_quant_per_layer"): | |||||
| """FakeQuantPerLayer""" | |||||
| input_shape = x.get("shape") | input_shape = x.get("shape") | ||||
| input_dtype = x.get("dtype") | input_dtype = x.get("dtype") | ||||
| min_shape = min_val.get("ori_shape") | min_shape = min_val.get("ori_shape") | ||||
| @@ -131,8 +131,8 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y, | |||||
| input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | ||||
| min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | ||||
| max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | ||||
| res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y, | |||||
| quant_min, quant_max, kernel_name) | |||||
| res = fake_quant_per_layer_compute(input_data, min_data, max_data, y, | |||||
| quant_min, quant_max, kernel_name) | |||||
| with tvm.target.cce(): | with tvm.target.cce(): | ||||
| sch = generic.auto_schedule(res) | sch = generic.auto_schedule(res) | ||||
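The wrapper feeds integer `quant_min`/`quant_max` into the compute function; the code that derives them lies outside this hunk, but a common derivation from `num_bits` and `narrow_range` looks roughly like the sketch below (assumed, not taken from the patch):

```python
def derive_quant_range(num_bits, narrow_range):
    """Hypothetical helper: the [quant_min, quant_max] integer range for num_bits."""
    quant_min = 0
    quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min += 1  # narrow range drops the lowest code, e.g. 8-bit uses 255 levels
    return quant_min, quant_max
```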
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """FakeQuantWithMinMaxGrad op""" | |||||
| """FakeQuantPerLayerGrad op""" | |||||
| from functools import reduce as functools_reduce | from functools import reduce as functools_reduce | ||||
| import te.lang.cce | import te.lang.cce | ||||
| @@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||||
| SHAPE_SIZE_LIMIT = 2147483648 | SHAPE_SIZE_LIMIT = 2147483648 | ||||
| D_TYPE = 'float32' | D_TYPE = 'float32' | ||||
| fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \ | |||||
| fake_quant_per_layer_grad_op_info = TBERegOp("FakeQuantPerLayerGrad") \ | |||||
| .fusion_type("OPAQUE") \ | .fusion_type("OPAQUE") \ | ||||
| .async_flag(False) \ | .async_flag(False) \ | ||||
| .binfile_name("fake_quant_with_min_max_grad.so") \ | |||||
| .binfile_name("fake_quant_per_layer_grad.so") \ | |||||
| .compute_cost(10) \ | .compute_cost(10) \ | ||||
| .kernel_name("fake_quant_with_min_max_grad") \ | |||||
| .kernel_name("fake_quant_per_layer_grad") \ | |||||
| .partial_flag(True) \ | .partial_flag(True) \ | ||||
| .attr("num_bits", "optional", "int", "all") \ | .attr("num_bits", "optional", "int", "all") \ | ||||
| .attr("quant_delay", "optional", "int", "all") \ | |||||
| .attr("symmetric", "optional", "bool", "all") \ | .attr("symmetric", "optional", "bool", "all") \ | ||||
| .attr("narrow_range", "optional", "bool", "all") \ | .attr("narrow_range", "optional", "bool", "all") \ | ||||
| .input(0, "dout", None, "required", None) \ | .input(0, "dout", None, "required", None) \ | ||||
| @@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y): | |||||
| min_value = tvm.const(2 ** (-126), dtype=D_TYPE) | min_value = tvm.const(2 ** (-126), dtype=D_TYPE) | ||||
| max_value = tvm.const(2 ** 62, dtype=D_TYPE) | max_value = tvm.const(2 ** 62, dtype=D_TYPE) | ||||
| factor_value = tvm.const(2 ** 2, dtype=D_TYPE) | factor_value = tvm.const(2 ** 2, dtype=D_TYPE) | ||||
| data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE) | |||||
| data_zero = te.lang.cce.broadcast( | |||||
| tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE) | |||||
| min_value_tensor = te.lang.cce.vadds(data_zero, min_value) | min_value_tensor = te.lang.cce.vadds(data_zero, min_value) | ||||
| res_sub = te.lang.cce.vsub(data_y, data_x) | res_sub = te.lang.cce.vsub(data_y, data_x) | ||||
| @@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y): | |||||
| return res | return res | ||||
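`_less_compare_float32` builds a 0/1 mask for `x < y` with pure float arithmetic so the comparison stays on the vector unit; assuming the elided tail of the function follows the usual TBE pattern (scaling the clamped difference back up to exactly 1.0 via the `max_value`/`factor_value` constants), a NumPy sketch of the idea is:

```python
import numpy as np

def less_compare_float32_ref(data_x, data_y):
    """Illustrative NumPy analogue of _less_compare_float32 (tail of the TBE code assumed)."""
    min_value = 2.0 ** -126                        # smallest positive normal float32
    diff = np.minimum(data_y - data_x, min_value)  # positive diffs collapse to min_value
    mask = np.maximum(diff, 0.0)                   # non-positive diffs collapse to 0
    # 2**-126 * 2**62 * 2**62 * 2**2 == 1.0, so the mask becomes exactly 0 or 1.
    return mask * (2.0 ** 62) * (2.0 ** 62) * (2.0 ** 2)
```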
| @op_info_register(fake_quant_grad_op_info) | |||||
| def _fake_quant_grad_tbe(): | |||||
| """FakeQuantWithMinMaxGrad TBE register""" | |||||
| @op_info_register(fake_quant_per_layer_grad_op_info) | |||||
| def _fake_quant_per_layer_grad_tbe(): | |||||
| """FakeQuantPerLayerGrad TBE register""" | |||||
| return | return | ||||
| @fusion_manager.register("fake_quant_with_min_max_grad") | |||||
| def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, | |||||
| kernel_name="fake_quant_with_min_max_grad"): | |||||
| """FakeQuantWithMinMaxGrad""" | |||||
| @fusion_manager.register("fake_quant_per_layer_grad") | |||||
| def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, | |||||
| kernel_name="fake_quant_per_layer_grad"): | |||||
| """FakeQuantPerLayerGrad""" | |||||
| shape = te.lang.cce.util.shape_to_list(x.shape) | shape = te.lang.cce.util.shape_to_list(x.shape) | ||||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | ||||
| quant_min = tvm.const(quant_min, x.dtype) | quant_min = tvm.const(quant_min, x.dtype) | ||||
| @@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q | |||||
| quant_max = te.lang.cce.broadcast(quant_max, shape_min) | quant_max = te.lang.cce.broadcast(quant_max, shape_min) | ||||
| # CalNudge(NudgeMinMax) | # CalNudge(NudgeMinMax) | ||||
| scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) | |||||
| scale = te.lang.cce.vdiv(te.lang.cce.vsub( | |||||
| max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) | |||||
| zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) | zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale)) | ||||
| # Nudge zero point | # Nudge zero point | ||||
| nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min))) | |||||
| nudge_zp_ = te.lang.cce.vmin( | |||||
| quant_max, te.lang.cce.vmax(quant_min, zp_from_min)) | |||||
| nudge_zp = te.lang.cce.floor(te.lang.cce.vadds(nudge_zp_, 0.5)) | |||||
| nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) | nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale) | ||||
| nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) | nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale) | ||||
| nudge_min = te.lang.cce.broadcast(nudge_min, shape) | nudge_min = te.lang.cce.broadcast(nudge_min, shape) | ||||
| @@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q | |||||
| return res | return res | ||||
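The backward pass is a straight-through estimator: `dout` passes through where `x` falls inside the nudged `[nudge_min, nudge_max]` range and is zeroed elsewhere (the elided lines build that mask with `_less_compare_float32`). A hedged NumPy sketch:

```python
import numpy as np

def fake_quant_per_layer_grad_ref(dout, x, nudge_min, nudge_max):
    """Illustrative straight-through gradient for the fake-quant op."""
    inside = (x >= nudge_min) & (x <= nudge_max)
    return dout * inside.astype(dout.dtype)
```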
| @util.check_input_type(dict, dict, dict, dict, dict, int, int, bool, bool, str) | |||||
| def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, | |||||
| num_bits, quant_delay, symmetric, narrow_range, | |||||
| kernel_name="fake_quant_with_min_max_grad"): | |||||
| """FakeQuantWithMinMaxGrad""" | |||||
| @util.check_input_type(dict, dict, dict, dict, dict, int, bool, bool, str) | |||||
| def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx, | |||||
| num_bits, symmetric, narrow_range, | |||||
| kernel_name="fake_quant_per_layer_grad"): | |||||
| """FakeQuantPerLayerGrad""" | |||||
| input_shape = x.get("shape") | input_shape = x.get("shape") | ||||
| input_dtype = x.get("dtype") | input_dtype = x.get("dtype") | ||||
| min_shape = min_val.get("ori_shape") | min_shape = min_val.get("ori_shape") | ||||
| @@ -152,8 +155,8 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, | |||||
| input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | ||||
| min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | ||||
| max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | ||||
| res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data, quant_min, | |||||
| quant_max, kernel_name) | |||||
| res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min, | |||||
| quant_max, kernel_name) | |||||
| with tvm.target.cce(): | with tvm.target.cce(): | ||||
| sch = generic.auto_schedule(res) | sch = generic.auto_schedule(res) | ||||
| @@ -20,10 +20,12 @@ from ..._checkparam import Rel | |||||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | from ..primitive import PrimitiveWithInfer, prim_attr_register | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| __all__ = ["FakeQuantWithMinMax", | |||||
| "FakeQuantWithMinMaxGrad", | |||||
| "FakeQuantWithMinMaxPerChannel", | |||||
| "FakeQuantWithMinMaxPerChannelGrad", | |||||
| __all__ = ["FakeQuantPerLayer", | |||||
| "FakeQuantPerLayerGrad", | |||||
| "FakeQuantPerChannel", | |||||
| "FakeQuantPerChannelGrad", | |||||
| "FakeQuantMinMaxPerLayerUpdate", | |||||
| "FakeQuantMinMaxPerChannelUpdate", | |||||
| "BatchNormFold", | "BatchNormFold", | ||||
| "BatchNormFoldGrad", | "BatchNormFoldGrad", | ||||
| "CorrectionMul", | "CorrectionMul", | ||||
| @@ -36,11 +38,10 @@ __all__ = ["FakeQuantWithMinMax", | |||||
| "BatchNormFold2_D", | "BatchNormFold2_D", | ||||
| "BatchNormFold2GradD", | "BatchNormFold2GradD", | ||||
| "BatchNormFold2GradReduce", | "BatchNormFold2GradReduce", | ||||
| "FakeQuantWithMinMaxUpdate", | |||||
| ] | ] | ||||
| class FakeQuantWithMinMax(PrimitiveWithInfer): | |||||
| class FakeQuantPerLayer(PrimitiveWithInfer): | |||||
| r""" | r""" | ||||
| Simulate the quantize and dequantize operations in training time. | Simulate the quantize and dequantize operations in training time. | ||||
| @@ -67,49 +68,67 @@ class FakeQuantWithMinMax(PrimitiveWithInfer): | |||||
| >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | ||||
| >>> min_tensor = Tensor(np.array([-6]), mstype.float32) | >>> min_tensor = Tensor(np.array([-6]), mstype.float32) | ||||
| >>> max_tensor = Tensor(np.array([6]), mstype.float32) | >>> max_tensor = Tensor(np.array([6]), mstype.float32) | ||||
| >>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) | |||||
| >>> output_tensor = P.FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) | |||||
| """ | """ | ||||
| support_quant_bit = [4, 7, 8] | support_quant_bit = [4, 7, 8] | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, | |||||
| def __init__(self, | |||||
| num_bits=8, | |||||
| ema=False, | |||||
| ema_decay=0.999, | |||||
| quant_delay=0, | |||||
| symmetric=False, | |||||
| narrow_range=False, | |||||
| training=True): | training=True): | ||||
| """init FakeQuantWithMinMax OP""" | |||||
| """init FakeQuantPerLayer OP""" | |||||
| if num_bits not in self.support_quant_bit: | if num_bits not in self.support_quant_bit: | ||||
| raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| if ema and not ema_decay: | if ema and not ema_decay: | ||||
| raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | ||||
| self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) | |||||
| self.training = validator.check_value_type('training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||||
| self.symmetric = validator.check_value_type( | |||||
| 'symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type( | |||||
| 'narrow_range', narrow_range, (bool,), self.name) | |||||
| self.training = validator.check_value_type( | |||||
| 'training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_number_range( | |||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type( | |||||
| 'quant_delay', quant_delay, (int,), self.name) | |||||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | self.init_prim_io_names(inputs=['x', 'min', 'max'], | ||||
| outputs=['out']) | outputs=['out']) | ||||
| def infer_shape(self, x_shape, min_shape, max_shape): | def infer_shape(self, x_shape, min_shape, max_shape): | ||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | ||||
| validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) | |||||
| validator.check("min shape", min_shape, "max shape", | |||||
| max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min rank", len( | |||||
| min_shape), 1, Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | valid_types = (mstype.float16, mstype.float32) | ||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | ||||
| validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| return x_type | return x_type | ||||
| class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): | |||||
| class FakeQuantPerLayerGrad(PrimitiveWithInfer): | |||||
| r""" | r""" | ||||
| Performs grad of FakeQuantWithMinMax operation. | |||||
| Performs grad of FakeQuantPerLayer operation. | |||||
| Examples: | Examples: | ||||
| >>> fake_min_max_grad = P.FakeQuantWithMinMaxGrad() | |||||
| >>> fake_min_max_grad = P.FakeQuantPerLayerGrad() | |||||
| >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32) | >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32) | ||||
| >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32) | >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32) | ||||
| >>> _min = Tensor(np.array([-4]), mindspore.float32) | >>> _min = Tensor(np.array([-4]), mindspore.float32) | ||||
| @@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): | |||||
| support_quant_bit = [4, 7, 8] | support_quant_bit = [4, 7, 8] | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False): | |||||
| def __init__(self, | |||||
| num_bits=8, | |||||
| quant_delay=0, | |||||
| symmetric=False, | |||||
| narrow_range=False): | |||||
| if num_bits not in self.support_quant_bit: | if num_bits not in self.support_quant_bit: | ||||
| raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||||
| self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) | |||||
| self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type( | |||||
| 'quant_delay', quant_delay, (int,), self.name) | |||||
| self.symmetric = validator.check_value_type( | |||||
| 'symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type( | |||||
| 'narrow_range', narrow_range, (bool,), self.name) | |||||
| self.init_prim_io_names( | |||||
| inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) | |||||
| def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): | def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): | ||||
| validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) | |||||
| validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) | |||||
| validator.check("dout shape", dout_shape, "x shape", | |||||
| x_shape, Rel.EQ, self.name) | |||||
| validator.check("min shape", min_shape, "max shape", | |||||
| max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min rank", len( | |||||
| min_shape), 1, Rel.EQ, self.name) | |||||
| return dout_shape | return dout_shape | ||||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | def infer_dtype(self, dout_type, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | valid_types = (mstype.float16, mstype.float32) | ||||
| validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"dout": dout_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | ||||
| validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| return dout_type | return dout_type | ||||
| class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): | |||||
| class FakeQuantPerChannel(PrimitiveWithInfer): | |||||
| r""" | r""" | ||||
| Simulate the quantize and dequantize operations in training time base on per channel. | Simulate the quantize and dequantize operations in training time base on per channel. | ||||
| @@ -168,53 +203,73 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): | |||||
| - Tensor, has the same type as input. | - Tensor, has the same type as input. | ||||
| Examples: | Examples: | ||||
| >>> fake_quant = P.FakeQuantWithMinMaxPerChannel() | |||||
| >>> fake_quant = P.FakeQuantPerChannel() | |||||
| >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32) | >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32) | ||||
| >>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32) | >>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32) | ||||
| >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32) | >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32) | ||||
| >>> result = fake_quant(input_x, _min, _max) | >>> result = fake_quant(input_x, _min, _max) | ||||
| """ | """ | ||||
| support_quant_bit = [4, 7, 8] | support_quant_bit = [4, 7, 8] | ||||
| channel_axis = 0 | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, | |||||
| training=True): | |||||
| """init FakeQuantWithMinMaxPerChannel OP""" | |||||
| def __init__(self, | |||||
| num_bits=8, | |||||
| ema=False, | |||||
| ema_decay=0.999, | |||||
| quant_delay=0, | |||||
| symmetric=False, | |||||
| narrow_range=False, | |||||
| training=True, | |||||
| channel_axis=1): | |||||
| """init FakeQuantPerChannel OP""" | |||||
| if num_bits not in self.support_quant_bit: | if num_bits not in self.support_quant_bit: | ||||
| raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.") | |||||
| raise ValueError( | |||||
| f"For '{self.name}' Attr \'num_bits\' is not support.") | |||||
| if ema and not ema_decay: | if ema and not ema_decay: | ||||
| raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | ||||
| self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) | |||||
| self.training = validator.check_value_type('training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||||
| self.symmetric = validator.check_value_type( | |||||
| 'symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type( | |||||
| 'narrow_range', narrow_range, (bool,), self.name) | |||||
| self.training = validator.check_value_type( | |||||
| 'training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_number_range( | |||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type( | |||||
| 'quant_delay', quant_delay, (int,), self.name) | |||||
| self.channel_axis = validator.check_integer( | |||||
| 'channel_axis', channel_axis, 0, Rel.GE, self.name) | |||||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) | self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) | ||||
| def infer_shape(self, x_shape, min_shape, max_shape): | def infer_shape(self, x_shape, min_shape, max_shape): | ||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | ||||
| validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||||
| validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||||
| validator.check_integer( | |||||
| "min shape[0]", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||||
| validator.check_integer( | |||||
| "max shape[0]", max_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | valid_types = (mstype.float16, mstype.float32) | ||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | ||||
| validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| return x_type | return x_type | ||||
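With the new `channel_axis` attribute, `min` and `max` hold one value per channel, and `infer_shape` requires their length to equal `x.shape[channel_axis]`. A usage sketch in the style of the docstring example (shapes and values are illustrative):

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# One min/max entry per channel along channel_axis=0: shape (6,) matches x.shape[0].
x = Tensor(np.random.randn(6, 3, 5, 5).astype(np.float32))
_min = Tensor(np.full(6, -4.0, dtype=np.float32))
_max = Tensor(np.full(6, 4.0, dtype=np.float32))
fake_quant = P.FakeQuantPerChannel(num_bits=8, channel_axis=0)
out = fake_quant(x, _min, _max)  # same shape and dtype as x
```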
| class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): | |||||
| class FakeQuantPerChannelGrad(PrimitiveWithInfer): | |||||
| r""" | r""" | ||||
| Performs grad of FakeQuantWithMinMaxPerChannel operation. | |||||
| Performs grad of FakeQuantPerChannel operation. | |||||
| Examples: | Examples: | ||||
| >>> fqmmpc_grad = P.FakeQuantWithMinMaxPerChannelGrad() | |||||
| >>> fqmmpc_grad = P.FakeQuantPerChannelGrad() | |||||
| >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32) | >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32) | ||||
| >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32) | >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32) | ||||
| >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32) | >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32) | ||||
| @@ -224,16 +279,29 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): | |||||
| support_quant_bit = [4, 7, 8] | support_quant_bit = [4, 7, 8] | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False): | |||||
| """init FakeQuantWithMinMaxPerChannel Fill""" | |||||
| def __init__(self, | |||||
| num_bits=8, | |||||
| quant_delay=0, | |||||
| symmetric=False, | |||||
| narrow_range=False, | |||||
| channel_axis=1): | |||||
| """init FakeQuantPerChannelGrad Fill""" | |||||
| if num_bits not in self.support_quant_bit: | if num_bits not in self.support_quant_bit: | ||||
| raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||||
| self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) | |||||
| self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type( | |||||
| 'quant_delay', quant_delay, (int,), self.name) | |||||
| self.symmetric = validator.check_value_type( | |||||
| 'symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type( | |||||
| 'narrow_range', narrow_range, (bool,), self.name) | |||||
| self.channel_axis = validator.check_integer( | |||||
| 'channel_axis', channel_axis, 0, Rel.GE, self.name) | |||||
| self.init_prim_io_names( | |||||
| inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) | |||||
| def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): | def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): | ||||
| validator.check("dout shape", dout_shape, "x shape", x_shape) | validator.check("dout shape", dout_shape, "x shape", x_shape) | ||||
| @@ -242,10 +310,13 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): | |||||
| def infer_dtype(self, dout_type, x_type, min_type, max_type): | def infer_dtype(self, dout_type, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | valid_types = (mstype.float16, mstype.float32) | ||||
| validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"dout": dout_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | ||||
| validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| return dout_type | return dout_type | ||||
| @@ -744,17 +815,14 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): | |||||
| return dout_type, dout_type | return dout_type, dout_type | ||||
| class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer): | |||||
| class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer): | |||||
| r""" | r""" | ||||
| Simulate the quantize and dequantize operations in training time. | |||||
| Update min and max value for fake quant per layer op. | |||||
| Args: | Args: | ||||
| num_bits (int) : Number bits for aware quantilization. Default: 8. | num_bits (int) : Number bits for aware quantilization. Default: 8. | ||||
| ema (bool): Use EMA algorithm update value min and max. Default: False. | ema (bool): Use EMA algorithm update value min and max. Default: False. | ||||
| ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | ||||
| quant_delay (int): Quantilization delay parameter. Before delay step in training time not update | |||||
| simulate aware quantize funcion. After delay step in training time begin simulate the aware | |||||
| quantize funcion. Default: 0. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | ||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | ||||
| training (bool): Training the network or not. Default: True. | training (bool): Training the network or not. Default: True. | ||||
| @@ -776,36 +844,121 @@ class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer): | |||||
| support_quant_bit = [4, 7, 8] | support_quant_bit = [4, 7, 8] | ||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, | |||||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | |||||
| training=True): | training=True): | ||||
| """init FakeQuantWithMinMax OP""" | |||||
| """init FakeQuantMinMaxPerLayerUpdate OP""" | |||||
| from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad | from mindspore.ops._op_impl._custom_op import correction_mul, correction_mul_grad | ||||
| from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad | from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max, fake_quant_with_min_max_grad | ||||
| from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update | from mindspore.ops._op_impl._custom_op import fake_quant_with_min_max_update | ||||
| if num_bits not in self.support_quant_bit: | if num_bits not in self.support_quant_bit: | ||||
| raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| if ema and not ema_decay: | if ema and not ema_decay: | ||||
| raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | ||||
| self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) | |||||
| self.training = validator.check_value_type('training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) | |||||
| self.symmetric = validator.check_value_type( | |||||
| 'symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type( | |||||
| 'narrow_range', narrow_range, (bool,), self.name) | |||||
| self.training = validator.check_value_type( | |||||
| 'training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_number_range( | |||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | self.init_prim_io_names(inputs=['x', 'min', 'max'], | ||||
| outputs=['min_up', 'max_up']) | outputs=['min_up', 'max_up']) | ||||
| def infer_shape(self, x_shape, min_shape, max_shape): | def infer_shape(self, x_shape, min_shape, max_shape): | ||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | ||||
| validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) | |||||
| validator.check("min shape", min_shape, "max shape", | |||||
| max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min rank", len( | |||||
| min_shape), 1, Rel.EQ, self.name) | |||||
| return min_shape, max_shape | return min_shape, max_shape | ||||
| def infer_dtype(self, x_type, min_type, max_type): | def infer_dtype(self, x_type, min_type, max_type): | ||||
| valid_types = (mstype.float16, mstype.float32) | valid_types = (mstype.float16, mstype.float32) | ||||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | ||||
| validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| return min_type, max_type | |||||
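`FakeQuantMinMaxPerLayerUpdate` returns updated `min_up`/`max_up` rather than a quantized tensor; conceptually it blends the batch extrema into the running range, with EMA when `ema` is set. A hedged NumPy sketch of that semantics (the Ascend kernel itself is not shown in this diff, so the clamping details are assumed):

```python
import numpy as np

def min_max_per_layer_update_ref(x, running_min, running_max, ema, ema_decay):
    """Illustrative per-layer min/max update; details of the real kernel are assumed."""
    batch_min, batch_max = float(np.min(x)), float(np.max(x))
    if ema:
        min_up = running_min * ema_decay + batch_min * (1.0 - ema_decay)
        max_up = running_max * ema_decay + batch_max * (1.0 - ema_decay)
    else:
        min_up, max_up = batch_min, batch_max
    # Keep zero inside the range so it stays exactly representable (assumed behaviour).
    return min(min_up, 0.0), max(max_up, 0.0)
```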
| class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer): | |||||
| r""" | |||||
| Update min and max value for fake quant per channel op. | |||||
| Args: | |||||
| num_bits (int) : Number of bits for aware quantization. Default: 8. | |||||
| ema (bool): Use EMA algorithm update value min and max. Default: False. | |||||
| ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. | |||||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||||
| training (bool): Training the network or not. Default: True. | |||||
| channel_axis (int): Channel axis for per channel computation. Default: 1. | |||||
| Inputs: | |||||
| - **x** (Tensor) : Input data, a float32 tensor. | |||||
| - **min** (Tensor) : Value of the min range of the input data x. | |||||
| - **max** (Tensor) : Value of the max range of the input data x. | |||||
| Outputs: | |||||
| - **min_up** (Tensor) : Updated min value of the input data x. | |||||
| - **max_up** (Tensor) : Updated max value of the input data x. | |||||
| Examples: | |||||
| >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | |||||
| >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | |||||
| >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | |||||
| >>> min_up, max_up = P.FakeQuantMinMaxPerChannelUpdate(num_bits=8)(x, min, max) | |||||
| """ | |||||
| support_quant_bit = [4, 7, 8] | |||||
| @prim_attr_register | |||||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | |||||
| training=True, channel_axis=1): | |||||
| """init FakeQuantPerChannelUpdate OP for Ascend""" | |||||
| if num_bits not in self.support_quant_bit: | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||||
| if ema and not ema_decay: | |||||
| raise ValueError( | |||||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | |||||
| self.symmetric = validator.check_value_type( | |||||
| 'symmetric', symmetric, (bool,), self.name) | |||||
| self.narrow_range = validator.check_value_type( | |||||
| 'narrow_range', narrow_range, (bool,), self.name) | |||||
| self.training = validator.check_value_type( | |||||
| 'training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_number_range( | |||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.channel_axis = validator.check_integer( | |||||
| 'channel_axis', channel_axis, 0, Rel.GE, self.name) | |||||
| self.init_prim_io_names( | |||||
| inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) | |||||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) | |||||
| validator.check("min shape", min_shape, "max shape", | |||||
| max_shape, Rel.EQ, self.name) | |||||
| validator.check_integer("min rank", len( | |||||
| min_shape), 1, Rel.EQ, self.name) | |||||
| return min_shape, max_shape | |||||
| def infer_dtype(self, x_type, min_type, max_type): | |||||
| valid_types = (mstype.float16, mstype.float32) | |||||
| validator.check_tensor_type_same( | |||||
| {"x": x_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"min": min_type}, valid_types, self.name) | |||||
| validator.check_tensor_type_same( | |||||
| {"max": max_type}, valid_types, self.name) | |||||
| return min_type, max_type | return min_type, max_type | ||||
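The per-channel variant performs the same update but reduces every axis except `channel_axis`, so `min_up`/`max_up` keep one entry per channel. A hedged sketch, again with kernel details assumed:

```python
import numpy as np

def min_max_per_channel_update_ref(x, running_min, running_max, ema, ema_decay, channel_axis=1):
    """Illustrative per-channel min/max update; the real kernel may differ in details."""
    reduce_axes = tuple(i for i in range(x.ndim) if i != channel_axis)
    batch_min = x.min(axis=reduce_axes)
    batch_max = x.max(axis=reduce_axes)
    if ema:
        min_up = running_min * ema_decay + batch_min * (1.0 - ema_decay)
        max_up = running_max * ema_decay + batch_max * (1.0 - ema_decay)
    else:
        min_up, max_up = batch_min, batch_max
    return np.minimum(min_up, 0.0), np.maximum(max_up, 0.0)
```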