Merge pull request !2451 from 王东旭/master (tag: v0.6.0-beta)
| @@ -23,35 +23,24 @@ | |||
| #include "device/gpu/cuda_common.h" | |||
| __global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min, | |||
| float *output_max, const float min, const float max, const float decay, | |||
| const float symmetric) { | |||
| float *output_max, const float min, const float max, | |||
| const float decay) { | |||
| output_min[0] = decay * (min) + (1 - decay) * (input_min[0]); | |||
| output_min[0] = input_min[0] > 0 ? 0 : input_min[0]; | |||
| output_max[0] = decay * (max) + (1 - decay) * (input_max[0]); | |||
| output_max[0] = input_max[0] < 0 ? 0 : input_max[0]; | |||
| if (symmetric) { | |||
| output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0]; | |||
| output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0]; | |||
| } | |||
| return; | |||
| } | |||
| __global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max, | |||
| const float symmetric) { | |||
| __global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) { | |||
| output_min[0] = min > 0 ? 0 : min; | |||
| output_max[0] = max < 0 ? 0 : max; | |||
| if (symmetric) { | |||
| output_max[0] = abs(output_min[0]) < output_max[0] ? output_max[0] : -output_min[0]; | |||
| output_min[0] = abs(output_min[0]) < output_max[0] ? -output_max[0] : output_min[0]; | |||
| } | |||
| return; | |||
| } | |||
| __global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, | |||
| float *output_max, int channels, int per_channel_nums, bool ema, | |||
| float ema_decay, bool symmetric) { | |||
| float ema_decay) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { | |||
| thrust::pair<float *, float *> sum = | |||
| thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); | |||
| @@ -64,27 +53,21 @@ __global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, floa | |||
| } | |||
| output_min[i] = input_min[i] > 0 ? 0 : input_min[i]; | |||
| output_max[i] = input_max[i] < 0 ? 0 : input_max[i]; | |||
| if (symmetric) { | |||
| output_max[i] = abs(output_min[i]) < output_max[i] ? output_max[i] : -output_min[i]; | |||
| output_min[i] = abs(output_min[i]) < output_max[i] ? -output_max[i] : output_min[i]; | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, | |||
| const int total_num, const int channel_num, const float ema_decay, const bool ema, | |||
| const bool symmetric, cudaStream_t cuda_stream) { | |||
| cudaStream_t cuda_stream) { | |||
| int per_channel_num = total_num / channel_num; | |||
| UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>( | |||
| input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay, symmetric); | |||
| input, input_min, input_max, output_min, output_max, channel_num, per_channel_num, ema, ema_decay); | |||
| return; | |||
| } | |||
| void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, | |||
| const int total_num, const float ema_decay, const bool ema, const bool symmetric, | |||
| cudaStream_t cuda_stream) { | |||
| const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) { | |||
| float minel = 0.f; | |||
| float maxel = 0.f; | |||
| auto policy = thrust::cuda::par.on(cuda_stream); | |||
| @@ -96,9 +79,9 @@ void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float * | |||
| if (ema) { | |||
| UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel, | |||
| maxel, ema_decay, symmetric); | |||
| maxel, ema_decay); | |||
| } else { | |||
| UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel, symmetric); | |||
| UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel); | |||
| } | |||
| return; | |||
| } | |||
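For reference, the per-layer update these kernels now perform (with the symmetric adjustment removed) can be sketched in NumPy. This is a rough equivalent, not the kernel itself; note that the kernel weights the incoming batch statistic by `decay`.

```python
import numpy as np

def minmax_update_per_layer(x, running_min, running_max, ema=False, ema_decay=0.999):
    """Rough NumPy equivalent of CalMinMaxPerLayer (symmetric handling removed)."""
    batch_min = float(np.min(x))
    batch_max = float(np.max(x))
    if ema:
        # Same convention as the kernel above: decay weights the batch statistic.
        new_min = ema_decay * batch_min + (1.0 - ema_decay) * running_min
        new_max = ema_decay * batch_max + (1.0 - ema_decay) * running_max
    else:
        new_min, new_max = batch_min, batch_max
    # Keep zero representable in the quantization range.
    return min(new_min, 0.0), max(new_max, 0.0)
```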
| @@ -21,10 +21,9 @@ | |||
| void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max, | |||
| const int total_num, const int channel_num, const float ema_decay, const bool ema, | |||
| const bool symmetric, cudaStream_t cuda_stream); | |||
| cudaStream_t cuda_stream); | |||
| void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max, | |||
| const int size, const float ema_decay, const bool ema, const bool symmetric, | |||
| cudaStream_t cuda_stream); | |||
| const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_ | |||
| @@ -24,16 +24,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel() | |||
| : input_size_(0), | |||
| num_bits_(0), | |||
| quant_min_(0), | |||
| quant_max_(0), | |||
| quant_num_(1), | |||
| ema_(false), | |||
| ema_decay_(0), | |||
| num_channels_(0), | |||
| narrow_range_(false), | |||
| symmetric_(false) {} | |||
| : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0), num_channels_(0) {} | |||
| const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||
| @@ -54,22 +45,8 @@ bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; | |||
| } | |||
| num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); | |||
| ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); | |||
| ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); | |||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | |||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | |||
| if (num_bits_ <= 2 || num_bits_ >= 16) { | |||
| MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; | |||
| } | |||
| // quant min and max | |||
| quant_min_ = 0; | |||
| quant_max_ = (1 << num_bits_) - 1; | |||
| if (narrow_range_) { | |||
| quant_min_++; | |||
| } | |||
| // init size | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| @@ -110,7 +87,7 @@ bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inpu | |||
| // calculate the input min and max according to the parameters ema and ema_decay. | |||
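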
| CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_, | |||
| ema_decay_, ema_, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| @@ -44,15 +44,10 @@ class MinMaxUpdatePerChannelGpuKernel : public GpuKernel { | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| int num_bits_; | |||
| float quant_min_; | |||
| float quant_max_; | |||
| int quant_num_; | |||
| bool ema_; | |||
| float ema_decay_; | |||
| int num_channels_; | |||
| bool narrow_range_; | |||
| bool symmetric_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
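The per-channel path works the same way, keeping one (min, max) pair per channel over a channel-major layout; a rough NumPy sketch (the function name is illustrative):

```python
import numpy as np

def minmax_update_per_channel(x, running_min, running_max, channels,
                              ema=False, ema_decay=0.999):
    """Rough NumPy equivalent of CalMinMaxPerChannel (symmetric handling removed)."""
    flat = x.reshape(channels, -1)   # assumes channel-major layout, per_channel_num elements each
    batch_min = flat.min(axis=1)
    batch_max = flat.max(axis=1)
    if ema:
        # Same blending convention as the per-layer kernel above.
        batch_min = ema_decay * batch_min + (1.0 - ema_decay) * running_min
        batch_max = ema_decay * batch_max + (1.0 - ema_decay) * running_max
    # Clamp so zero stays inside every channel's range.
    return np.minimum(batch_min, 0.0), np.maximum(batch_max, 0.0)
```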
| @@ -24,15 +24,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel() | |||
| : input_size_(0), | |||
| num_bits_(0), | |||
| quant_min_(0), | |||
| quant_max_(0), | |||
| quant_num_(1), | |||
| ema_(false), | |||
| ema_decay_(0), | |||
| narrow_range_(false), | |||
| symmetric_(false) {} | |||
| : input_size_(0), quant_num_(1), ema_(false), ema_decay_(0) {} | |||
| const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; } | |||
| @@ -51,22 +43,8 @@ bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output."; | |||
| } | |||
| num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); | |||
| ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); | |||
| ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); | |||
| symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric")); | |||
| narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range")); | |||
| if (num_bits_ <= 2 || num_bits_ >= 16) { | |||
| MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16."; | |||
| } | |||
| // quant min and max | |||
| quant_min_ = 0; | |||
| quant_max_ = (1 << num_bits_) - 1; | |||
| if (narrow_range_) { | |||
| quant_min_++; | |||
| } | |||
| // init size | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| @@ -104,7 +82,7 @@ bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs | |||
| MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null."; | |||
| } | |||
| CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, symmetric_, | |||
| CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| @@ -44,14 +44,9 @@ class MinMaxUpdatePerLayerGpuKernel : public GpuKernel { | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| int num_bits_; | |||
| float quant_min_; | |||
| float quant_max_; | |||
| int quant_num_; | |||
| bool ema_; | |||
| float ema_decay_; | |||
| bool narrow_range_; | |||
| bool symmetric_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
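The `num_bits`/`narrow_range`/`symmetric` bookkeeping deleted from these kernels only matters to the fake-quant ops themselves. As a reminder of what was removed, a small sketch of the range derivation (the symmetric branch follows the TBE code also touched below; the function name is illustrative):

```python
def quant_range(num_bits=8, symmetric=False, narrow_range=False):
    """Sketch of the quantization-range bookkeeping removed from the min/max kernels."""
    if symmetric:
        quant_min = -(2 ** (num_bits - 1))
        quant_max = 2 ** (num_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
        if narrow_range:
            quant_min += 1
    return quant_min, quant_max
```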
| @@ -276,15 +276,15 @@ class FakeQuantWithMinMax(Cell): | |||
| Args: | |||
| min_init (int, float): Initial minimum of the quantization range, with shape of channel or 1 (layer). Default: -6. | |||
| max_init (int, float): Initial maximum of the quantization range, with shape of channel or 1 (layer). Default: 6. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| ema (bool): Exponential Moving Average algorithm update min and max. Default: False. | |||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| channel_axis (int): Quantization by channel axis. Default: 1. | |||
| out_channels (int): declares the min and max channel size. Default: 1. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| num_channels (int): declares the min and max channel size. Default: 1. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of FakeQuantWithMinMax. | |||
| @@ -301,15 +301,15 @@ class FakeQuantWithMinMax(Cell): | |||
| def __init__(self, | |||
| min_init=-6, | |||
| max_init=6, | |||
| num_bits=8, | |||
| ema=False, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| channel_axis=1, | |||
| out_channels=1, | |||
| quant_delay=0, | |||
| num_channels=1, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| """init FakeQuantWithMinMax layer""" | |||
| super(FakeQuantWithMinMax, self).__init__() | |||
| self.min_init = min_init | |||
| @@ -318,7 +318,7 @@ class FakeQuantWithMinMax(Cell): | |||
| self.ema = ema | |||
| self.ema_decay = ema_decay | |||
| self.per_channel = per_channel | |||
| self.out_channels = out_channels | |||
| self.num_channels = num_channels | |||
| self.channel_axis = channel_axis | |||
| self.quant_delay = quant_delay | |||
| self.symmetric = symmetric | |||
| @@ -327,11 +327,11 @@ class FakeQuantWithMinMax(Cell): | |||
| # init tensor min and max for fake quant op | |||
| if self.per_channel: | |||
| min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) | |||
| max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32) | |||
| min_array = np.array([self.min_init] * self.num_channels).astype(np.float32) | |||
| max_array = np.array([self.max_init] * self.num_channels).astype(np.float32) | |||
| else: | |||
| min_array = np.array([self.min_init]).reshape(1).astype(np.float32) | |||
| max_array = np.array([self.max_init]).reshape(1).astype(np.float32) | |||
| min_array = np.array([self.min_init]).astype(np.float32) | |||
| max_array = np.array([self.max_init]).astype(np.float32) | |||
| self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) | |||
| self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) | |||
| @@ -343,57 +343,41 @@ class FakeQuantWithMinMax(Cell): | |||
| quant_fun = Q.FakeQuantPerLayer | |||
| ema_fun = Q.MinMaxUpdatePerLayer | |||
| self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay) | |||
| if self.is_ascend: | |||
| self.fake_quant = quant_fun(num_bits=self.num_bits, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range) | |||
| else: | |||
| self.fake_quant_train = quant_fun(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=ema_decay, | |||
| quant_delay=quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=True) | |||
| self.fake_quant_infer = quant_fun(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=ema_decay, | |||
| quant_delay=quant_delay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| training=False) | |||
| self.ema_update = ema_fun(num_bits=self.num_bits, | |||
| ema=self.ema, | |||
| ema_decay=self.ema_decay, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range) | |||
| narrow_range=self.narrow_range) | |||
| self.fake_quant_infer = self.fake_quant_train | |||
| else: | |||
| quant_fun = partial(quant_fun, | |||
| ema=self.ema, | |||
| ema_decay=ema_decay, | |||
| num_bits=self.num_bits, | |||
| symmetric=self.symmetric, | |||
| narrow_range=self.narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.fake_quant_train = quant_fun(training=True) | |||
| self.fake_quant_infer = quant_fun(training=False) | |||
| def extend_repr(self): | |||
| s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ | |||
| 'quant_delay={}, min_init={}, max_init={}'.format( | |||
| self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel, | |||
| self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init) | |||
| self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init) | |||
| return s | |||
| def construct(self, x): | |||
| if self.is_ascend: | |||
| if self.training: | |||
| min_up, max_up = self.ema_update(x, self.minq, self.maxq) | |||
| out = self.fake_quant(x, min_up, max_up) | |||
| P.Assign()(self.minq, min_up) | |||
| P.Assign()(self.maxq, max_up) | |||
| else: | |||
| out = self.fake_quant(x, self.minq, self.maxq) | |||
| if self.training: | |||
| min_up, max_up = self.ema_update(x, self.minq, self.maxq) | |||
| P.Assign()(self.minq, min_up) | |||
| P.Assign()(self.maxq, max_up) | |||
| out = self.fake_quant_train(x, self.minq, self.maxq) | |||
| else: | |||
| if self.training: | |||
| min_up, max_up = self.ema_update(x, self.minq, self.maxq) | |||
| out = self.fake_quant_train(x, min_up, max_up) | |||
| P.Assign()(self.minq, min_up) | |||
| P.Assign()(self.maxq, max_up) | |||
| else: | |||
| out = self.fake_quant_infer(x, self.minq, self.maxq) | |||
| out = self.fake_quant_infer(x, self.minq, self.maxq) | |||
| return out | |||
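A minimal usage sketch of the reworked layer with the new keyword order; the import path is an assumption and may differ:

```python
import numpy as np
from mindspore import Tensor
from mindspore.nn.layer.quant import FakeQuantWithMinMax  # import path is an assumption

fake_quant = FakeQuantWithMinMax(min_init=-6, max_init=6,
                                 ema=True, ema_decay=0.999,
                                 per_channel=False,
                                 num_bits=8, symmetric=False,
                                 narrow_range=False, quant_delay=0)
x = Tensor(np.random.rand(2, 16, 5, 5).astype(np.float32))
y = fake_quant(x)  # same shape as x; values snapped to the simulated quantization grid
```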
| class Conv2dBatchNormQuant(Cell): | |||
| r""" | |||
| 2D convolution with BatchNormal op folded layer. | |||
| @@ -407,8 +391,8 @@ class Conv2dBatchNormQuant(Cell): | |||
| stride (int): Specifies stride for all spatial dimensions with the same value. | |||
| pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". | |||
| padding: (int): Implicit paddings on both sides of the input. Default: 0. | |||
| eps (int): Parameters for BatchNormal. Default: 1e-5. | |||
| momentum (int): Parameters for BatchNormal op. Default: 0.997. | |||
| eps (float): Parameters for BatchNormal. Default: 1e-5. | |||
| momentum (float): Parameters for BatchNormal op. Default: 0.997. | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| convolution kernel. Default: 'normal'. | |||
| beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| @@ -419,13 +403,13 @@ class Conv2dBatchNormQuant(Cell): | |||
| mean vector. Default: 'zeros'. | |||
| var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the | |||
| variance vector. Default: 'ones'. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. | |||
| fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. | |||
| per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| @@ -456,13 +440,13 @@ class Conv2dBatchNormQuant(Cell): | |||
| gamma_init='ones', | |||
| mean_init='zeros', | |||
| var_init='ones', | |||
| quant_delay=0, | |||
| freeze_bn=100000, | |||
| fake=True, | |||
| num_bits=8, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0, | |||
| freeze_bn=100000): | |||
| """init Conv2dBatchNormQuant layer""" | |||
| super(Conv2dBatchNormQuant, self).__init__() | |||
| self.in_channels = in_channels | |||
| @@ -519,12 +503,13 @@ class Conv2dBatchNormQuant(Cell): | |||
| self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| per_channel=per_channel, | |||
| out_channels=out_channels, | |||
| channel_axis=channel_axis, | |||
| num_channels=out_channels, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) | |||
| self.correct_mul = Q.CorrectionMul(channel_axis) | |||
| if context.get_context('device_target') == "Ascend": | |||
| @@ -598,11 +583,11 @@ class Conv2dQuant(Cell): | |||
| weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. | |||
| Default: 'normal'. | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| @@ -629,11 +614,11 @@ class Conv2dQuant(Cell): | |||
| has_bias=False, | |||
| weight_init='normal', | |||
| bias_init='zeros', | |||
| quant_delay=0, | |||
| num_bits=8, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| super(Conv2dQuant, self).__init__() | |||
| if isinstance(kernel_size, int): | |||
| self.kernel_size = (kernel_size, kernel_size) | |||
| @@ -669,12 +654,13 @@ class Conv2dQuant(Cell): | |||
| self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| per_channel=per_channel, | |||
| out_channels=out_channels, | |||
| channel_axis=0, | |||
| num_channels=out_channels, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| def construct(self, x): | |||
| weight = self.fake_quant_weight(self.weight) | |||
| @@ -708,11 +694,11 @@ class DenseQuant(Cell): | |||
| same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||
| activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. | |||
| per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. | |||
| @@ -734,19 +720,19 @@ class DenseQuant(Cell): | |||
| bias_init='zeros', | |||
| has_bias=True, | |||
| activation=None, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| super(DenseQuant, self).__init__() | |||
| self.in_channels = check_int_positive(in_channels) | |||
| self.out_channels = check_int_positive(out_channels) | |||
| self.has_bias = check_bool(has_bias) | |||
| if isinstance(weight_init, Tensor): | |||
| if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ | |||
| weight_init.shape[1] != in_channels: | |||
| if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ | |||
| weight_init.shape()[1] != in_channels: | |||
| raise ValueError("weight_init shape error") | |||
| self.weight = Parameter(initializer( | |||
| @@ -754,7 +740,7 @@ class DenseQuant(Cell): | |||
| if self.has_bias: | |||
| if isinstance(bias_init, Tensor): | |||
| if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: | |||
| if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: | |||
| raise ValueError("bias_init shape error") | |||
| self.bias = Parameter(initializer( | |||
| @@ -768,12 +754,13 @@ class DenseQuant(Cell): | |||
| self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| ema=False, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| per_channel=per_channel, | |||
| out_channels=out_channels, | |||
| channel_axis=0, | |||
| num_channels=out_channels, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| def construct(self, x): | |||
| """Use operators to construct to Dense layer.""" | |||
| @@ -796,13 +783,16 @@ class DenseQuant(Cell): | |||
| return str_info | |||
| class _QuantActivation(Cell): | |||
| r""" | |||
| Base class for Quant activation function. Add Fake Quant OP after activation OP. | |||
| """ | |||
| def get_origin(self): | |||
| raise NotImplementedError | |||
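Every wrapper below follows this contract: run the original op in `construct`, fake-quantize its output, and expose the wrapped op through `get_origin`. A hypothetical subclass written inside this module (`TanhQuant` is illustrative, not part of this change) would look like:

```python
class TanhQuant(_QuantActivation):
    """Illustrative only: fake-quantize the output of Tanh.

    Relies on this module's existing imports (P = mindspore.ops.operations,
    FakeQuantWithMinMax defined above)."""

    def __init__(self, num_bits=8, ema_decay=0.999):
        super(TanhQuant, self).__init__()
        self.act = P.Tanh()
        self.fake_quant_act = FakeQuantWithMinMax(min_init=-1, max_init=1,
                                                  ema=True, ema_decay=ema_decay,
                                                  num_bits=num_bits)

    def construct(self, x):
        x = self.act(x)
        return self.fake_quant_act(x)

    def get_origin(self):
        return self.act
```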
| class ReLUQuant(_QuantActivation): | |||
| r""" | |||
| ReLUQuant activation function. Add Fake Quant OP after Relu OP. | |||
| @@ -810,12 +800,12 @@ class ReLUQuant(_QuantActivation): | |||
| For a more Detailed overview of ReLU op. | |||
| Args: | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of ReLUQuant. | |||
| @@ -830,22 +820,22 @@ class ReLUQuant(_QuantActivation): | |||
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| super(ReLUQuant, self).__init__() | |||
| self.fake_quant_act = FakeQuantWithMinMax(min_init=0, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| per_channel=per_channel, | |||
| ema_decay=ema_decay, | |||
| per_channel=per_channel, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.relu = P.ReLU() | |||
| def construct(self, x): | |||
| @@ -866,12 +856,12 @@ class ReLU6Quant(_QuantActivation): | |||
| For a more Detailed overview of ReLU6 op. | |||
| Args: | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of ReLU6Quant. | |||
| @@ -886,22 +876,22 @@ class ReLU6Quant(_QuantActivation): | |||
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| super(ReLU6Quant, self).__init__() | |||
| self.fake_quant_act = FakeQuantWithMinMax(min_init=0, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| per_channel=per_channel, | |||
| ema_decay=ema_decay, | |||
| per_channel=per_channel, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.relu6 = P.ReLU6() | |||
| def construct(self, x): | |||
| @@ -912,6 +902,7 @@ class ReLU6Quant(_QuantActivation): | |||
| def get_origin(self): | |||
| return self.relu6 | |||
| class HSwishQuant(_QuantActivation): | |||
| r""" | |||
| HSwishQuant activation function. Add Fake Quant OP after HSwish OP. | |||
| @@ -919,12 +910,12 @@ class HSwishQuant(_QuantActivation): | |||
| For a more Detailed overview of HSwish op. | |||
| Args: | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of HSwishQuant. | |||
| @@ -939,31 +930,31 @@ class HSwishQuant(_QuantActivation): | |||
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| super(HSwishQuant, self).__init__() | |||
| self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| per_channel=per_channel, | |||
| ema_decay=ema_decay, | |||
| per_channel=per_channel, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| per_channel=per_channel, | |||
| ema_decay=ema_decay, | |||
| per_channel=per_channel, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.act = P.HSwish() | |||
| def construct(self, x): | |||
| @@ -975,6 +966,7 @@ class HSwishQuant(_QuantActivation): | |||
| def get_origin(self): | |||
| return self.act | |||
| class HSigmoidQuant(_QuantActivation): | |||
| r""" | |||
| HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP. | |||
| @@ -982,12 +974,12 @@ class HSigmoidQuant(_QuantActivation): | |||
| For a more Detailed overview of HSigmoid op. | |||
| Args: | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of HSigmoidQuant. | |||
| @@ -1002,30 +994,31 @@ class HSigmoidQuant(_QuantActivation): | |||
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| super(HSigmoidQuant, self).__init__() | |||
| self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| ema_decay=ema_decay, | |||
| per_channel=per_channel, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| per_channel=per_channel, | |||
| ema_decay=ema_decay, | |||
| per_channel=per_channel, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.act = P.HSigmoid() | |||
| def construct(self, x): | |||
| @@ -1037,6 +1030,7 @@ class HSigmoidQuant(_QuantActivation): | |||
| def get_origin(self): | |||
| return self.act | |||
| class TensorAddQuant(Cell): | |||
| r""" | |||
| Add Fake Quant OP after TensorAdd OP. | |||
| @@ -1044,12 +1038,12 @@ class TensorAddQuant(Cell): | |||
| For a more Detailed overview of TensorAdd op. | |||
| Args: | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of TensorAddQuant. | |||
| @@ -1065,22 +1059,22 @@ class TensorAddQuant(Cell): | |||
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| super(TensorAddQuant, self).__init__() | |||
| self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| per_channel=per_channel, | |||
| ema_decay=ema_decay, | |||
| per_channel=per_channel, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.add = P.TensorAdd() | |||
| def construct(self, x1, x2): | |||
| @@ -1096,12 +1090,12 @@ class MulQuant(Cell): | |||
| For a more Detailed overview of Mul op. | |||
| Args: | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. | |||
| per_channel (bool): Quantization granularity based on layer or on channel. Default: False. | |||
| num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| quant_delay (int): Quantization delay parameters according by global step. Default: 0. | |||
| Inputs: | |||
| - **x** (Tensor) - The input of MulQuant. | |||
| @@ -1112,22 +1106,22 @@ class MulQuant(Cell): | |||
| """ | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| ema_decay=0.999, | |||
| per_channel=False, | |||
| num_bits=8, | |||
| symmetric=False, | |||
| narrow_range=False): | |||
| narrow_range=False, | |||
| quant_delay=0): | |||
| super(MulQuant, self).__init__() | |||
| self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, | |||
| max_init=6, | |||
| num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| ema=True, | |||
| per_channel=per_channel, | |||
| ema_decay=ema_decay, | |||
| per_channel=per_channel, | |||
| num_bits=num_bits, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range) | |||
| narrow_range=narrow_range, | |||
| quant_delay=quant_delay) | |||
| self.mul = P.Mul() | |||
| def construct(self, x1, x2): | |||
| @@ -1,4 +1,3 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| @@ -22,20 +21,15 @@ from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel") \ | |||
| minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fake_quant_min_max_per_channel_update.so") \ | |||
| .binfile_name("minmax_update_perchannel.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fake_quant_min_max_per_channel_update") \ | |||
| .kernel_name("minmax_update_perchannel") \ | |||
| .partial_flag(True) \ | |||
| .attr("ema", "optional", "bool", "all") \ | |||
| .attr("ema_decay", "optional", "float", "all") \ | |||
| .attr("symmetric", "optional", "bool", "all") \ | |||
| .attr("narrow_range", "optional", "bool", "all") \ | |||
| .attr("training", "optional", "bool", "all") \ | |||
| .attr("num_bits", "optional", "int", "all") \ | |||
| .attr("channel_axis", "optional", "int", "all") \ | |||
| .input(0, "x", None, "required", None) \ | |||
| .input(1, "min", None, "required", None) \ | |||
| @@ -47,43 +41,46 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("MinMaxUpdatePerChannel | |||
| .get_op_info() | |||
| @op_info_register(fake_quant_min_max_per_channel_update_op_info) | |||
| def _fake_quant_min_max_per_channel_update_tbe(): | |||
| """FakeQuantPerChannelUpdate TBE register""" | |||
| @op_info_register(minmax_update_perchannel_op_info) | |||
| def _minmax_update_perchannel_tbe(): | |||
| """MinMaxUpdatePerChannel TBE register""" | |||
| return | |||
| @fusion_manager.register("fake_quant_min_max_per_channel_update") | |||
| def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val, | |||
| ema, ema_decay, quant_min, quant_max, training, channel_axis, | |||
| kernel_name="fake_quant_min_max_per_channel_update"): | |||
| """FakeQuantPerChannelUpdate compute""" | |||
| @fusion_manager.register("minmax_update_perchannel") | |||
| def minmax_update_perchannel_compute(x, min_val, max_val, | |||
| ema, ema_decay, channel_axis): | |||
| """MinMaxUpdatePerChannel compute""" | |||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | |||
| if not ema: | |||
| ema_decay = 0.0 | |||
| if training: | |||
| # CalMinMax | |||
| # CalMinMax | |||
| if channel_axis == 0: | |||
| axis = [1, 2, 3, 4] | |||
| else: | |||
| axis = [0, 2, 3] | |||
| x_min = te.lang.cce.reduce_min(x, axis=axis) | |||
| x_max = te.lang.cce.reduce_max(x, axis=axis) | |||
| x_min = te.lang.cce.broadcast(x_min, shape_min) | |||
| x_max = te.lang.cce.broadcast(x_max, shape_min) | |||
| min_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||
| min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) | |||
| max_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||
| max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) | |||
| min_val = te.lang.cce.vmins(min_val, 0) | |||
| max_val = te.lang.cce.vmaxs(max_val, 0) | |||
| x_min = te.lang.cce.reduce_min(x, axis=axis) | |||
| x_max = te.lang.cce.reduce_max(x, axis=axis) | |||
| x_min = te.lang.cce.broadcast(x_min, shape_min) | |||
| x_max = te.lang.cce.broadcast(x_max, shape_min) | |||
| min_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||
| min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) | |||
| max_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||
| max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) | |||
| min_val = te.lang.cce.vmins(min_val, 0) | |||
| max_val = te.lang.cce.vmaxs(max_val, 0) | |||
| return [min_val, max_val] | |||
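A rough NumPy rendering of the compute above; the axis choice mirrors the code (reduce everything except the channel dimension implied by `channel_axis`), and the names are illustrative:

```python
import numpy as np

def minmax_update_perchannel_reference(x, min_val, max_val, ema, ema_decay, channel_axis):
    """NumPy sketch of minmax_update_perchannel_compute (not the TBE kernel itself)."""
    if not ema:
        ema_decay = 0.0
    axis = (1, 2, 3, 4) if channel_axis == 0 else (0, 2, 3)
    x_min = x.min(axis=axis)
    x_max = x.max(axis=axis)
    # Here decay weights the running value, as in the compute above.
    min_val = ema_decay * min_val + (1.0 - ema_decay) * x_min
    max_val = ema_decay * max_val + (1.0 - ema_decay) * x_max
    return np.minimum(min_val, 0.0), np.maximum(max_val, 0.0)
```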
| @util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) | |||
| def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, | |||
| ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis, | |||
| kernel_name="fake_quant_min_max_per_channel_update"): | |||
| """FakeQuantPerLayer op""" | |||
| @util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str) | |||
| def minmax_update_perchannel(x, min_val, max_val, min_up, max_up, | |||
| ema, ema_decay, channel_axis, | |||
| kernel_name="minmax_update_perchannel"): | |||
| """MinMaxUpdatePerChannel op""" | |||
| x_shape = x.get("ori_shape") | |||
| x_format = x.get("format") | |||
| x_dtype = x.get("dtype") | |||
| @@ -108,21 +105,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, | |||
| util.check_dtype_rule(min_dtype, check_list) | |||
| util.check_dtype_rule(max_dtype, check_list) | |||
| if symmetric: | |||
| quant_min = 0 - 2 ** (num_bits - 1) | |||
| quant_max = 2 ** (num_bits - 1) - 1 | |||
| if channel_axis == 0: | |||
| shape_c = min_val.get("ori_shape") | |||
| else: | |||
| quant_min = 0 | |||
| quant_max = 2 ** num_bits - 1 | |||
| if narrow_range: | |||
| quant_min = quant_min + 1 | |||
| shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] | |||
| shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] | |||
| input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype) | |||
| min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) | |||
| max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) | |||
| res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data, | |||
| ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name) | |||
| res_list = minmax_update_perchannel_compute(input_data, min_data, max_data, | |||
| ema, ema_decay, channel_axis) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res_list) | |||
| @@ -22,20 +22,15 @@ from topi import generic | |||
| from topi.cce import util | |||
| from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType | |||
| fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \ | |||
| minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \ | |||
| .fusion_type("OPAQUE") \ | |||
| .async_flag(False) \ | |||
| .binfile_name("fake_quant_minmax_update.so") \ | |||
| .binfile_name("minmax_update_perlayer.so") \ | |||
| .compute_cost(10) \ | |||
| .kernel_name("fake_quant_minmax_update") \ | |||
| .kernel_name("minmax_update_perlayer") \ | |||
| .partial_flag(True) \ | |||
| .attr("ema", "optional", "bool", "all") \ | |||
| .attr("ema_decay", "optional", "float", "all") \ | |||
| .attr("symmetric", "optional", "bool", "all") \ | |||
| .attr("narrow_range", "optional", "bool", "all") \ | |||
| .attr("training", "optional", "bool", "all") \ | |||
| .attr("num_bits", "optional", "int", "all") \ | |||
| .input(0, "x", None, "required", None) \ | |||
| .input(1, "min", None, "required", None) \ | |||
| .input(2, "max", None, "required", None) \ | |||
| @@ -46,15 +41,14 @@ fake_quant_minmax_update_op_info = TBERegOp("MinMaxUpdatePerLayer") \ | |||
| .get_op_info() | |||
| @op_info_register(fake_quant_minmax_update_op_info) | |||
| def _fake_quant_minmax_update_tbe(): | |||
| @op_info_register(minmax_update_perlayer_op_info) | |||
| def _minmax_update_perlayer_tbe(): | |||
| """MinMaxUpdatePerLayer TBE register""" | |||
| return | |||
| @fusion_manager.register("fake_quant_minmax_update") | |||
| def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, | |||
| kernel_name="fake_quant_minmax_update"): | |||
| @fusion_manager.register("minmax_update_perlayer") | |||
| def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay): | |||
| """MinMaxUpdatePerLayer compute""" | |||
| shape = te.lang.cce.util.shape_to_list(x.shape) | |||
| shape_min = te.lang.cce.util.shape_to_list(min_val.shape) | |||
| @@ -62,28 +56,27 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_ | |||
| max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) | |||
| if not ema: | |||
| ema_decay = 0.0 | |||
| if training: | |||
| # CalMinMax | |||
| axis = tuple(range(len(shape))) | |||
| x_min = te.lang.cce.reduce_min(x, axis=axis) | |||
| x_max = te.lang.cce.reduce_max(x, axis=axis) | |||
| x_min = te.lang.cce.broadcast(x_min, shape_min) | |||
| x_max = te.lang.cce.broadcast(x_max, shape_min) | |||
| min_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||
| min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) | |||
| max_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||
| max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) | |||
| min_val = te.lang.cce.vmins(min_val, 0) | |||
| max_val = te.lang.cce.vmaxs(max_val, 0) | |||
| # CalMinMax | |||
| axis = tuple(range(len(shape))) | |||
| x_min = te.lang.cce.reduce_min(x, axis=axis) | |||
| x_max = te.lang.cce.reduce_max(x, axis=axis) | |||
| x_min = te.lang.cce.broadcast(x_min, shape_min) | |||
| x_max = te.lang.cce.broadcast(x_max, shape_min) | |||
| min_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||
| min_val, ema_decay), te.lang.cce.vmuls(x_min, (1 - ema_decay))) | |||
| max_val = te.lang.cce.vadd(te.lang.cce.vmuls( | |||
| max_val, ema_decay), te.lang.cce.vmuls(x_max, (1 - ema_decay))) | |||
| min_val = te.lang.cce.vmins(min_val, 0) | |||
| max_val = te.lang.cce.vmaxs(max_val, 0) | |||
| return [min_val, max_val] | |||
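The per-layer variant differs only in reducing over every axis; a matching NumPy sketch:

```python
import numpy as np

def minmax_update_perlayer_reference(x, min_val, max_val, ema, ema_decay):
    """NumPy sketch of minmax_update_perlayer_compute (not the TBE kernel itself)."""
    if not ema:
        ema_decay = 0.0
    min_val = ema_decay * min_val + (1.0 - ema_decay) * x.min()
    max_val = ema_decay * max_val + (1.0 - ema_decay) * x.max()
    return np.minimum(min_val, 0.0), np.maximum(max_val, 0.0)
```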
| @util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str) | |||
| def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, | |||
| ema, ema_decay, symmetric, narrow_range, training, num_bits, | |||
| kernel_name="fake_quant_minmax_update"): | |||
| """FakeQuantPerLayer op""" | |||
| @util.check_input_type(dict, dict, dict, dict, dict, bool, float, str) | |||
| def minmax_update_perlayer(x, min_val, max_val, min_up, max_up, | |||
| ema, ema_decay, kernel_name="minmax_update_perlayer"): | |||
| """MinMaxUpdatePerLayer op""" | |||
| input_shape = x.get("shape") | |||
| input_dtype = x.get("dtype") | |||
| min_shape = min_val.get("ori_shape") | |||
| @@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, | |||
| input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) | |||
| shape_min, _, _ = util.produce_shapes(min_shape, input_shape) | |||
| if symmetric: | |||
| quant_min = 0 - 2 ** (num_bits - 1) | |||
| quant_max = 2 ** (num_bits - 1) - 1 | |||
| else: | |||
| quant_min = 0 | |||
| quant_max = 2 ** num_bits - 1 | |||
| if narrow_range: | |||
| quant_min = quant_min + 1 | |||
| input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) | |||
| min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) | |||
| max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) | |||
| res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data, | |||
| ema, ema_decay, quant_min, quant_max, training, kernel_name) | |||
| res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay) | |||
| with tvm.target.cce(): | |||
| sch = generic.auto_schedule(res_list) | |||
| @@ -21,12 +21,12 @@ from ..._checkparam import Rel | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||
| from ...common import dtype as mstype | |||
| __all__ = ["FakeQuantPerLayer", | |||
| __all__ = ["MinMaxUpdatePerLayer", | |||
| "MinMaxUpdatePerChannel", | |||
| "FakeQuantPerLayer", | |||
| "FakeQuantPerLayerGrad", | |||
| "FakeQuantPerChannel", | |||
| "FakeQuantPerChannelGrad", | |||
| "MinMaxUpdatePerLayer", | |||
| "MinMaxUpdatePerChannel", | |||
| "BatchNormFold", | |||
| "BatchNormFoldGrad", | |||
| "CorrectionMul", | |||
| @@ -38,10 +38,128 @@ __all__ = ["FakeQuantPerLayer", | |||
| "BatchNormFoldGradD", | |||
| "BatchNormFold2_D", | |||
| "BatchNormFold2GradD", | |||
| "BatchNormFold2GradReduce", | |||
| "BatchNormFold2GradReduce" | |||
| ] | |||
| class MinMaxUpdatePerLayer(PrimitiveWithInfer): | |||
| r""" | |||
| Update min and max per layer. | |||
| Args: | |||
| ema (bool): Whether to use the EMA algorithm to update min and max. Default: False. | |||
| ema_decay (float) : EMA algorithm decay parameter. Default: 0.999. | |||
| Inputs: | |||
| - **x** (Tensor) : float32 tensor of the input data. | |||
| - **min** (Tensor) : Value of the min range of the input data x. | |||
| - **max** (Tensor) : Value of the max range of the input data x. | |||
| Outputs: | |||
| - **min_up** (Tensor) : Updated minimum value, with the same shape as **min**. | |||
| - **max_up** (Tensor) : Updated maximum value, with the same shape as **max**. | |||
| Examples: | |||
| >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | |||
| >>> min_tensor = Tensor(np.array([-6]), mstype.float32) | |||
| >>> max_tensor = Tensor(np.array([6]), mstype.float32) | |||
| >>> min_up, max_up = MinMaxUpdatePerLayer(ema=False)(input_tensor, min_tensor, max_tensor) | |||
| """ | |||
| support_quant_bit = [4, 7, 8] | |||
| @prim_attr_register | |||
| def __init__(self, ema=False, ema_decay=0.999): | |||
| """init FakeQuantMinMaxPerLayerUpdate OP""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import minmax_update_perlayer | |||
| if ema and not ema_decay: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | |||
| outputs=['min_up', 'max_up']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | |||
| validator.check("min shape", min_shape, "max shape", | |||
| max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min shape", len( | |||
| min_shape), 1, Rel.EQ, self.name) | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| return min_type, max_type | |||
| class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||
| r""" | |||
| Update min and max per channel. | |||
| Args: | |||
| ema (bool): Whether to use the EMA algorithm to update min and max. Default: False. | |||
| ema_decay (float) : EMA algorithm decay parameter. Default: 0.999. | |||
| channel_axis (int): Channel axis for the per-channel computation. Default: 1. | |||
| Inputs: | |||
| - **x** (Tensor) : float32 tensor of the input data. | |||
| - **min** (Tensor) : Value of the min range of the input data x. | |||
| - **max** (Tensor) : Value of the max range of the input data x. | |||
| Outputs: | |||
| - **min_up** (Tensor) : Updated minimum value, with the same shape as **min**. | |||
| - **max_up** (Tensor) : Updated maximum value, with the same shape as **max**. | |||
| Examples: | |||
| >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | |||
| >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | |||
| >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | |||
| >>> min_up, max_up = MinMaxUpdatePerChannel(channel_axis=1)(x, min, max) | |||
| """ | |||
| support_quant_bit = [4, 7, 8] | |||
| @prim_attr_register | |||
| def __init__(self, ema=False, ema_decay=0.999, channel_axis=1): | |||
| """init FakeQuantPerChannelUpdate OP for Ascend""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import minmax_update_perchannel | |||
| if ema and not ema_decay: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.channel_axis = validator.check_integer( | |||
| 'channel axis', channel_axis, 0, Rel.GE, self.name) | |||
| self.init_prim_io_names( | |||
| inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) | |||
| validator.check("min shape", min_shape, "max shape", | |||
| max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min shape", len( | |||
| min_shape), 1, Rel.EQ, self.name) | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same( | |||
| {"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| return min_type, max_type | |||
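As a quick shape check of what the two new primitives' `infer_shape` methods enforce: for an NCHW input with `channel_axis=1`, min and max carry one element per channel. A usage sketch; the `_quant_ops` import path is an assumption:

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops.operations import _quant_ops as Q  # import path is an assumption

x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
minq = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
maxq = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)

min_up, max_up = Q.MinMaxUpdatePerChannel(ema=True, ema_decay=0.999, channel_axis=1)(x, minq, maxq)
# min_up and max_up keep the per-channel shape (16,)
```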
| class FakeQuantPerLayer(PrimitiveWithInfer): | |||
| r""" | |||
| Simulate the quantize and dequantize operations in training time. | |||
| @@ -832,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer): | |||
| def infer_dtype(self, dout_type, x_type): | |||
| validator.check("dout type", dout_type, "x type", x_type) | |||
| return dout_type, dout_type | |||
| class MinMaxUpdatePerLayer(PrimitiveWithInfer): | |||
| r""" | |||
| Update min and max value for fake quant per layer op. | |||
| Args: | |||
| num_bits (int) : Number bits for quantization aware. Default: 8. | |||
| ema (bool): Use EMA algorithm update value min and max. Default: False. | |||
| ema_decay (float) : EMA algorithm decay parameter. Default: 0.999. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| training (bool): Training the network or not. Default: True. | |||
| Inputs: | |||
| - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. | |||
| - **min** (Tensor) : Value of the min range of the input data x. | |||
| - **max** (Tensor) : Value of the max range of the input data x. | |||
| Outputs: | |||
| - Tensor: Simulate quantize tensor of x. | |||
| Examples: | |||
| >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | |||
| >>> min_tensor = Tensor(np.array([-6]), mstype.float32) | |||
| >>> max_tensor = Tensor(np.array([6]), mstype.float32) | |||
| >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor) | |||
| """ | |||
| support_quant_bit = [4, 7, 8] | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | |||
| training=True): | |||
| """init MinMaxUpdatePerLayer OP""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| if ema and not ema_decay: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | |||
| self.symmetric = validator.check_value_type( | |||
| 'symmetric', symmetric, (bool,), self.name) | |||
| self.narrow_range = validator.check_value_type( | |||
| 'narrow_range', narrow_range, (bool,), self.name) | |||
| self.training = validator.check_value_type( | |||
| 'training', training, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.num_bits = validator.check_integer( | |||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | |||
| outputs=['min_up', 'max_up']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) | |||
| validator.check("min shape", min_shape, "max shape", | |||
| max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min shape", len( | |||
| min_shape), 1, Rel.EQ, self.name) | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| return min_type, max_type | |||
| class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||
| r""" | |||
| Update min and max value for fake quant per channel op. | |||
| Args: | |||
| num_bits (int) : Number bits for quantization aware. Default: 8. | |||
| ema (bool): Use EMA algorithm update value min and max. Default: False. | |||
| ema_decay (float) : EMA algorithm decay parameter. Default: 0.999. | |||
| symmetric (bool): Quantization algorithm use symmetric or not. Default: False. | |||
| narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. | |||
| training (bool): Training the network or not. Default: True. | |||
| channel_axis (int): Channel axis for per channel compute. Default: 1. | |||
| Inputs: | |||
| - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. | |||
| - **min** (Tensor) : Value of the min range of the input data x. | |||
| - **max** (Tensor) : Value of the max range of the input data x. | |||
| Outputs: | |||
| - Tensor: Simulate quantize tensor of x. | |||
| Examples: | |||
| >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) | |||
| >>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | |||
| >>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32) | |||
| >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max) | |||
| """ | |||
| support_quant_bit = [4, 7, 8] | |||
| @prim_attr_register | |||
| def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False, | |||
| training=True, channel_axis=1): | |||
| """init MinMaxUpdatePerChannel OP for Ascend""" | |||
| if context.get_context('device_target') == "Ascend": | |||
| from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update | |||
| if num_bits not in self.support_quant_bit: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'num_bits\' is not support.") | |||
| if ema and not ema_decay: | |||
| raise ValueError( | |||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | |||
| self.symmetric = validator.check_value_type( | |||
| 'symmetric', symmetric, (bool,), self.name) | |||
| self.narrow_range = validator.check_value_type( | |||
| 'narrow_range', narrow_range, (bool,), self.name) | |||
| self.training = validator.check_value_type( | |||
| 'training', training, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.num_bits = validator.check_integer( | |||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||
| self.channel_axis = validator.check_integer( | |||
| 'channel axis', channel_axis, 0, Rel.GE, self.name) | |||
| self.init_prim_io_names( | |||
| inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) | |||
| def infer_shape(self, x_shape, min_shape, max_shape): | |||
| validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) | |||
| validator.check("min shape", min_shape, "max shape", | |||
| max_shape, Rel.EQ, self.name) | |||
| validator.check_integer("min shape", len( | |||
| min_shape), 1, Rel.EQ, self.name) | |||
| return min_shape, max_shape | |||
| def infer_dtype(self, x_type, min_type, max_type): | |||
| valid_types = (mstype.float16, mstype.float32) | |||
| validator.check_tensor_type_same( | |||
| {"x": x_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"min": min_type}, valid_types, self.name) | |||
| validator.check_tensor_type_same( | |||
| {"max": max_type}, valid_types, self.name) | |||
| return min_type, max_type | |||