@@ -20,7 +20,6 @@
 #include <thrust/reduce.h>
 #include <thrust/pair.h>
 #include "fake_quant_perchannel_impl.cuh"
-#include "device/gpu/cuda_common.h"

 /**
  * Find the nudge min, max and scale values as output.
@@ -34,13 +33,17 @@
  * @param channel_num
  * @return
  */
-__global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input_max, const float quant_min,
-                                      const float quant_max, float *nudge_min, float *nudge_max, float *scale,
-                                      int channel_num) {
+__global__ void NudgeMinMaxPerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                                      float *nudge_min, float *nudge_max, float *scale, int channel_num,
+                                      const bool symmetric) {
   float zp_from_min = 0.f;
   float nudge_zp = 0.f;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) {
+    if (symmetric) {
+      input_max[i] = abs(input_min[i]) < input_max[i] ? input_max[i] : -input_min[i];
+      input_min[i] = abs(input_min[i]) < input_max[i] ? -input_max[i] : input_min[i];
+    }
     if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) {
       scale[i] = 0.f;
       zp_from_min = 0.f;
@@ -62,11 +65,11 @@ __global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input
   }
 }

-void CalNudgePerChannel(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-                        float *nudge_min, float *nudge_max, float *scale, const int channel_num,
+void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
                         cudaStream_t cuda_stream) {
   NudgeMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
-    input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num);
+    input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num, symmetric);
 }
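
The arithmetic in the elided middle of this kernel follows the standard TensorFlow-style fake-quant nudging; the per-test comments in the new test files below (scale, zp, nudge) are computed the same way. A hedged host-side sketch in Python, assuming that scheme (note the device code rounds halves away from zero, unlike Python's round()):

import math

def nudge_min_max(input_min, input_max, quant_min, quant_max):
    # Degenerate range: the kernel zeroes scale and both nudged bounds.
    if (quant_max - quant_min) == 0 or (input_max - input_min) == 0:
        return 0.0, 0.0, 0.0
    scale = (input_max - input_min) / (quant_max - quant_min)
    zp_from_min = quant_min - input_min / scale
    # Clamp the zero point into the quantized domain, then round half up.
    nudged_zp = math.floor(min(max(zp_from_min, quant_min), quant_max) + 0.5)
    nudge_min = (quant_min - nudged_zp) * scale
    nudge_max = (quant_max - nudged_zp) * scale
    return nudge_min, nudge_max, scale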
 /**
@@ -80,9 +83,8 @@ void CalNudgePerChannel(const float *input_min, const float *input_max, const fl
  * @param scale - array
  * @return
  */
-__global__ void FakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
-                                       const float *nudge_min, const float *nudge_max, const float *scale,
-                                       bool symmetric) {
+__global__ void FakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size,
+                                    const float *nudge_min, const float *nudge_max, const float *scale) {
   float input_x = 0.f;
   int nudge_input = 0;
   int channel_idx = 0;
@@ -106,16 +108,15 @@ __global__ void FakeQuantizePerChannel(const float *input, float *output, const
   }
 }

-void CalFakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
-                               const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric,
-                               cudaStream_t cuda_stream) {
-  FakeQuantizePerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(
-    input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
+void CalFakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size,
+                            const float *nudge_min, const float *nudge_max, const float *scale,
+                            cudaStream_t cuda_stream) {
+  FakeQuantPerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(input, output, total_size, channel_size,
+                                                                               nudge_min, nudge_max, scale);
 }
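
The forward element-wise transform (its body is elided in this hunk) is the usual clamp/quantize/dequantize. A sketch under that assumption, consistent with the expectations asserted in the tests below:

import math

def fake_quant(x, nudge_min, nudge_max, scale):
    # Clamp into the nudged range, snap to the integer grid, then dequantize.
    clamped = min(max(x, nudge_min), nudge_max)
    return math.floor((clamped - nudge_min) / scale + 0.5) * scale + nudge_min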
-__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
-                                           const int total_size, const int channel_size, const float *nudge_min,
-                                           const float *nudge_max) {
+__global__ void FakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_size,
+                                        const int channel_size, const float *nudge_min, const float *nudge_max) {
   int channel_idx = 0;
   int per_channel_num = total_size / channel_size;
@@ -129,9 +130,9 @@ __global__ void FakeQuantizePerChannelGrad(const float *input, const float *grad
   }
 }

-void CalFakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
-                                   const int channel_num, const float *nudge_min, const float *nudge_max,
-                                   cudaStream_t cuda_stream) {
-  FakeQuantizePerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
-    input, gradient, output, total_num, channel_num, nudge_min, nudge_max);
+void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
+                                const int channel_num, const float *nudge_min, const float *nudge_max,
+                                cudaStream_t cuda_stream) {
+  FakeQuantPerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, total_num,
+                                                                                   channel_num, nudge_min, nudge_max);
 }
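
The backward kernel is a straight-through estimator: the incoming gradient passes through wherever the input fell inside the nudged range and is zeroed outside it (the per-layer variant further down spells this out in its loop body). In short:

def fake_quant_grad(dout, x, nudge_min, nudge_max):
    # Straight-through estimator: gradient is blocked outside the nudged range.
    return dout if nudge_min <= x <= nudge_max else 0.0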
@@ -14,22 +14,21 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_

-void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
-                        float* nudge_min, float* nudge_max, float* scale, const int channel_num,
-                        cudaStream_t cuda_stream);
+#include "device/gpu/cuda_common.h"

-void CalFakeQuantizePerChannel(const float* input, float* output, const int total_num, const int channel_num,
-                               const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
-                               cudaStream_t cuda_stream);
+void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
+                        cudaStream_t cuda_stream);

-void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_num, const int channel_num,
-                         const float ema_decay, const bool ema, cudaStream_t cuda_stream);
+void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num,
+                            const float *nudge_min, const float *nudge_max, const float *scale,
+                            cudaStream_t cuda_stream);

-void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
-                                   const int channel_num, const float* nudge_min, const float* nudge_max,
-                                   cudaStream_t cuda_stream);
+void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
+                                const int channel_num, const float *nudge_min, const float *nudge_max,
+                                cudaStream_t cuda_stream);

-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
@@ -17,11 +17,10 @@
 #include <thrust/extrema.h>
 #include <thrust/device_vector.h>
 #include <thrust/pair.h>
-#include "device/gpu/cuda_common.h"
 #include "fake_quant_perlayer_impl.cuh"

-__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
-                             const float *nudge_max, const float *scale) {
+__global__ void FakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                                  const float *nudge_max, const float *scale) {
   float input_x = 0.f;
   int nudge_input = 0;
@@ -43,8 +42,8 @@ __global__ void FakeQuantize(const float *input, float *output, const int size,
   return;
 }

-__global__ void FakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                                 const float *nudge_min, const float *nudge_max) {
+__global__ void FakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                                      const float *nudge_min, const float *nudge_max) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
     if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) {
       output[i] = 0;
@@ -55,12 +54,18 @@ __global__ void FakeQuantizeGrad(const float *input, const float *gradient, floa
   return;
 }

-__global__ void NudgeMinMax(const float *input_min, const float *input_max, const float quant_min,
-                            const float quant_max, float *nudge_min, float *nudge_max, float *scale) {
+__global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                                    float *nudge_min, float *nudge_max, float *scale, const bool symmetric) {
   float zp_from_min = 0.f;
   scale[0] = 0.f;
   nudge_max[0] = 0.f;
   nudge_min[0] = 0.f;
+  if (symmetric) {
+    input_max[0] = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0];
+    input_min[0] = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0];
+  }
   if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) {
     scale[0] = 0.f;
     zp_from_min = 0.f;
-__global__ void UpdateInputMinMaxWithEMA(float *input_min, float *input_max, const float min, const float max,
-                                         const float decay) {
-  input_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
-  input_min[0] = input_min[0] > 0 ? 0 : input_min[0];
-  input_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
-  input_max[0] = input_max[0] < 0 ? 0 : input_max[0];
-  return;
-}
-
-__global__ void UpdateInputMinMax(float *input_min, float *input_max, const float min, const float max) {
-  input_min[0] = min > 0 ? 0 : min;
-  input_max[0] = max < 0 ? 0 : max;
-}
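
These two (now removed) kernels maintained the running min/max statistics; with EMA enabled the update blends the batch extrema into the running values, and both bounds are clamped so the range always straddles zero. A small worked example of the same rule (decay value illustrative):

decay = 0.9
running_min, running_max = -0.2, 6.0
batch_min, batch_max = -0.4, 7.0

new_min = decay * batch_min + (1 - decay) * running_min   # -0.38
new_min = min(new_min, 0.0)   # lower bound is clamped to <= 0
new_max = decay * batch_max + (1 - decay) * running_max   # 6.9
new_max = max(new_max, 0.0)   # upper bound is clamped to >= 0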
-void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
-                     const float *scale, bool symmetric, cudaStream_t cuda_stream) {
-  FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale);
-  return;
-}
-
-void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                         const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
-  FakeQuantizeGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
-                                                                      nudge_max);
+void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                          const float *nudge_max, const float *scale, cudaStream_t cuda_stream) {
+  FakeQuantPerLayer<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max,
+                                                                       scale);
   return;
 }

-void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-              float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream) {
-  NudgeMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
+void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                              const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
+  FakeQuantPerLayerGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
+                                                                           nudge_max);
   return;
 }

-void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
-               cudaStream_t cuda_stream) {
-  float minel = 0.f;
-  float maxel = 0.f;
-  auto policy = thrust::cuda::par.on(cuda_stream);
-  thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
-  tuple = thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
-  minel = tuple.first[0];
-  maxel = tuple.second[0];
-  if (ema) {
-    UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay);
-  } else {
-    UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel);
-  }
+void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                      float *nudge_min, float *nudge_max, float *scale, const bool symmetric,
+                      cudaStream_t cuda_stream) {
+  NudgeMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale,
+                                                symmetric);
   return;
 }
@@ -14,19 +14,18 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_

-void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
-                     const float *scale, bool symmetric, cudaStream_t cuda_stream);
+#include "device/gpu/cuda_common.h"

-void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                         const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);
+void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                      float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream);

-void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-              float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream);
+void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                          const float *nudge_max, const float *scale, cudaStream_t cuda_stream);

-void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
-               cudaStream_t cuda_stream);
+void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                              const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);

-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
@@ -102,9 +102,9 @@ void FakeQuantPerChannelGpuKernel::InitSizeLists() {
 void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
                                                    float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) {
   CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
-                     reinterpret_cast<cudaStream_t>(stream_ptr));
-  CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
-                            symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+                     symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
 }

 bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
@@ -119,9 +119,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
   int total_size = input_size_ / sizeof(float);
   if (global_step_ >= quant_delay_) {
     CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
-                       reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
-                                  reinterpret_cast<cudaStream_t>(stream_ptr));
+                       symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
+                               reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
     CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -117,10 +117,10 @@ bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, c
     // control flow for quant_delay
     if (global_step_ >= quant_delay_) {
       // real launch
-      CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-               reinterpret_cast<cudaStream_t>(stream_ptr));
-      CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
-                      reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
+                           reinterpret_cast<cudaStream_t>(stream_ptr));
     } else {
       CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -129,10 +129,10 @@ bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, c
     global_step_++;
   } else {
     // real launch
-    CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
-                    reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
   }
   return true;
@@ -115,10 +115,10 @@ bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &input
   }
   if (global_step_ >= quant_delay_) {
-    CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
-                        reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
     CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -150,7 +150,7 @@ class ConvertToQuantNetwork:
             prefix = name
             add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                   num_bits=self.act_bits,
-                                                  quant_delay=self.act_delay,
+                                                  quant_delay=self.act_qdelay,
                                                   per_channel=self.act_channel,
                                                   symmetric=self.act_symmetric,
                                                   narrow_range=self.act_range)
@@ -408,19 +408,19 @@ def convert_quant_network(network,
     Args:
         network (Cell): Obtain a pipeline through network for saving graph summary.
-        quant_delay (int or tuple): Number of steps after which weights and activations are quantized during
-            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
         bn_fold (bool): Flag to use bn fold ops for simulation inference operation. Default: False.
         freeze_bn (int): Number of steps after which BatchNorm OP parameters use total mean and variance. Default: 0.
-        num_bits (int or tuple): Number of bits to use for quantizing weights and activations. The first
+        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
+            eval. The first element represents weights and the second element represents data flow. Default: (0, 0)
+        num_bits (int, list or tuple): Number of bits to use for quantizing weights and activations. The first
             element represents weights and the second element represents data flow. Default: (8, 8)
-        per_channel (int or tuple): Quantization granularity based on layer or on channel. If `True`
+        per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
             then base on per channel otherwise base on per layer. The first element represents weights
             and the second element represents data flow. Default: (False, False)
-        symmetric (int or tuple): Quantization algorithm use symmetric or not. If `True` then base on
-            symmetric otherwise base on assymmetric. The first element represent weights and second
+        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base
+            on symmetric otherwise base on asymmetric. The first element represents weights and the second
            element represents data flow. Default: (False, False)
-        narrow_range (int or tuple): Quantization algorithm use narrow range or not. If `True` then base
+        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not. If `True`
            then base on narrow range otherwise base on regular range. The first element represents weights and
            the second element represents data flow. Default: (False, False)
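
A minimal usage sketch for the updated signature (the import path and the TinyNet cell are illustrative assumptions, not part of this diff):

import mindspore.nn as nn
from mindspore.train.quant import quant  # assumed import path for this sketch

class TinyNet(nn.Cell):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(1, 6, 5)
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.conv(x))

quant_net = quant.convert_quant_network(TinyNet(),
                                        quant_delay=(0, 0),
                                        num_bits=(8, 8),
                                        per_channel=(True, False),
                                        symmetric=(True, False),
                                        narrow_range=(False, False))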
@@ -0,0 +1,625 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore import nn
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
    def __init__(self, num_bits=8, symmetric=False, narrow_range=False, channel_axis=1):
        super(Net, self).__init__()
        self.op = Q.FakeQuantPerChannel(num_bits=num_bits,
                                        symmetric=symmetric,
                                        narrow_range=narrow_range,
                                        channel_axis=channel_axis)

    def construct(self, x, minq, maxq):
        return self.op(x, minq, maxq)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel1():
    # WithVarsPerChannel_ZeroMinAndMax
    x = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    min_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    max_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel2():
    # WithVarsPerChannelDim1NudgedDown_RegularRange
    # scale 1/4, zp 0.4, nudge 0. nudged range [0.0, 63.75]
    x = np.array([-0.1, 0.0, 63.75, 63.8]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
    expect = np.array([0.0, 0.0, 63.75, 63.75]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)
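
The comment above can be reproduced by hand; a numpy re-derivation of this test's expectation (assuming the TF-style nudging sketched earlier):

import numpy as np

# 8 bits, regular range: quant_min = 0, quant_max = 255.
scale = (63.65 - (-0.1)) / (255 - 0)   # 0.25
zp_from_min = 0 - (-0.1) / scale       # 0.4 -> nudged zero point 0
nudge_min = (0 - 0) * scale            # 0.0
nudge_max = (255 - 0) * scale          # 63.75

x = np.array([-0.1, 0.0, 63.75, 63.8], np.float32)
clamped = np.clip(x, nudge_min, nudge_max)
expect = np.round((clamped - nudge_min) / scale) * scale + nudge_min
# -> [0.0, 0.0, 63.75, 63.75], as asserted above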


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel3():
    # WithVarsPerChannelDim1NudgedDown_NarrowRange
    # scale 1/4, zp 1.4, nudge 1. nudged range [0.0, 63.5]
    x = np.array([-0.1, 0.0, 63.5, 63.6]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
    expect = np.array([0.0, 0.0, 63.5, 63.5]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel4():
    # WithVarsPerChannelDim1NudgedUp_RegularRange
    # min/max [-0.125, 63.625]
    # scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5]
    x = np.array([-0.26, -0.25, -0.24, 63.6]).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 63.5]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel5():
    # WithVarsPerChannelDim1NudgedUp_NarrowRange
    # scale 1/4, zp: 1.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.26, -0.25, -0.24, 63.3]).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 63.25]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel6():
    # WithVarsPerChannelDim2NudgedDown_RegularRange
    # scale 1/4, zp: 0.4, nudge 0. nudged range [0.0, 63.75]
    x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.80]).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65]).reshape(3).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel7():
    # WithVarsPerChannelDim2NudgedDown_NarrowRange
    # scale 1/4, zp: 1.4, nudge 1. nudged range [0.0, 63.5]
    x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]).reshape(2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.25, 63.5, 63.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4]).reshape(3).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel8():
    # WithVarsPerChannelDim2NudgedUp_RegularRange
    # scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5]
    x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125]).reshape(3).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625]).reshape(3).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel9():
    # WithVarsPerChannelDim2NudgedUp_NarrowRange
    # scale 1/4, zp: 1.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.25, 63.25]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125]).reshape(3).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375]).reshape(3).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel10():
    # WithVarsPerChannelDim4NudgedDown_RegularRange
    # scale 1/4, zp: 0.4, nudge 0. nudged range [0.0, 63.75]
    x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                  1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                  63.0, 63.25, 63.5, 63.7, 63.75, 63.8,
                  63.9, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75,
                       1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                       63.0, 63.25, 63.5, 63.75, 63.75, 63.75,
                       63.75, 63.75, 63.75, 63.75, 63.75, 63.75]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65, 63.65]).reshape(4).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel11():
    # WithVarsPerChannelDim4NudgedDown_NarrowRange
    # scale 1/4, zp: 1.4, nudge 1. nudged range [0.0, 63.5]
    x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                  1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                  63.0, 63.25, 63.3, 63.4, 63.5, 63.6,
                  63.7, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75,
                       1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                       63.0, 63.25, 63.25, 63.5, 63.5, 63.5,
                       63.5, 63.5, 63.5, 63.5, 63.5, 63.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4, 63.4]).reshape(4).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel12():
    # WithVarsPerChannelDim4NudgedUp_RegularRange
    # scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5]
    x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5,
                  0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                  63.0, 63.25, 63.4, 63.5, 63.6, 63.7,
                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5,
                       0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                       63.0, 63.25, 63.5, 63.5, 63.5, 63.5,
                       63.5, 63.5, 63.5, 63.5, 63.5, 63.5]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).reshape(4).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625, 63.625]).reshape(4).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel13():
    # WithVarsPerChannelDim4NudgedUp_NarrowRange
    # scale 1/4, zp: 1.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5,
                  0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                  63.0, 63.2, 63.25, 63.3, 63.4, 63.5,
                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5,
                       0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                       63.0, 63.25, 63.25, 63.25, 63.25, 63.25,
                       63.25, 63.25, 63.25, 63.25, 63.25, 63.25]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).reshape(4).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375, 63.375]).reshape(4).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel14():
    # WithVarsPerChannelDim1NudgedDown_4Bits_RegularRange
    # scale 1/2, zp: 0.2, nudge 0. nudged range [0.0, 7.5]
    x = np.array([-0.1, 0.0, 7.5, 7.6]).reshape(4).astype(np.float32)
    expect = np.array([0.0, 0.0, 7.5, 7.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([7.4, 7.4, 7.4, 7.4]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)
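
The 4-bit variants follow the same arithmetic; for this test (quant_min = 0, quant_max = 15):

scale = (7.4 - (-0.1)) / (15 - 0)   # 0.5
zp_from_min = 0 - (-0.1) / scale    # 0.2 -> nudged zero point 0
nudged_range = (0.0, 15 * scale)    # (0.0, 7.5)
# -0.1 and 0.0 snap to 0.0; 7.5 stays on the grid; 7.6 clamps to 7.5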


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel15():
    # WithVarsPerChannelDim1NudgedDown_4Bits_NarrowRange
    # scale 1/2, zp: 1.2, nudge 1. nudged range [0.0, 7.0]
    x = np.array([-0.1, 0.0, 7.0, 7.1]).reshape(4).astype(np.float32)
    expect = np.array([0.0, 0.0, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([6.9, 6.9, 6.9, 6.9]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel16():
    # WithVarsPerChannelDim1NudgedUp_4Bits_RegularRange
    # scale 1/2, zp: 0.8, nudge 1. nudged range [-0.5, 7.0]
    x = np.array([-0.6, -0.5, 7.0, 7.1]).reshape(4).astype(np.float32)
    expect = np.array([-0.5, -0.5, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
    max_val = np.array([7.1, 7.1, 7.1, 7.1]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel17():
    # WithVarsPerChannelDim1NudgedUp_4Bits_NarrowRange
    # scale 1/2, zp: 1.8, nudge 2. nudged range [-0.5, 6.5]
    x = np.array([-0.6, -0.5, 6.5, 6.6]).reshape(4).astype(np.float32)
    expect = np.array([-0.5, -0.5, 6.5, 6.5]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
    max_val = np.array([6.6, 6.6, 6.6, 6.6]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel18():
    # WithVarsPerChannelDim2NudgedDown_4Bits_RegularRange
    # scale 1/2, zp: 0.2, nudge 0. nudged range [0.0, 7.5]
    x = np.array([-0.1, 0.0, 0.1, 0.5, 7.5, 7.6]).reshape(2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 7.5, 7.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
    max_val = np.array([7.4, 7.4, 7.4]).reshape(3).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel19():
    # WithVarsPerChannelDim2NudgedDown_4Bits_NarrowRange
    # scale 1/2, zp: 1.2, nudge 1. nudged range [0.0, 7.0]
    x = np.array([-0.1, 0.0, 0.1, 0.5, 7.0, 7.1]).reshape(2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
    max_val = np.array([6.9, 6.9, 6.9]).reshape(3).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel20():
    # WithVarsPerChannelDim2NudgedUp_4Bits_RegularRange
    # scale 1/2, zp: 0.8, nudge 1. nudged range [-0.5, 7.0]
    x = np.array([-0.51, -0.5, -0.24, 0.0, 7.0, 7.1]).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.5, -0.5, 0.0, 0.0, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4]).reshape(3).astype(np.float32)
    max_val = np.array([7.1, 7.1, 7.1]).reshape(3).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel21():
    # WithVarsPerChannelDim2NudgedUp_4Bits_NarrowRange
    # scale 1/2, zp: 1.8, nudge 2. nudged range [-0.5, 6.5]
    x = np.array([-0.6, -0.5, -0.24, 0.0, 6.5, 6.6]).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.5, -0.5, 0.0, 0.0, 6.5, 6.5]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4]).reshape(3).astype(np.float32)
    max_val = np.array([6.6, 6.6, 6.6]).reshape(3).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel22():
    # WithVarsPerChannelDim4NudgedDown_4Bits_RegularRange
    # scale 1/2, zp: 0.2, nudge 0. nudged range [0.0, 7.5]
    x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 7.0, 7.4, 7.5, 7.7,
                  7.8, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.5, 7.5, 7.5,
                       7.5, 7.5, 7.5, 7.5, 7.5, 7.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([7.4, 7.4, 7.4, 7.4]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel23():
    # WithVarsPerChannelDim4NudgedDown_4Bits_NarrowRange
    # scale 1/2, zp: 1.2, nudge 1. nudged range [0.0, 7.0]
    x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 6.8, 6.9, 7.0, 7.1,
                  7.2, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.0, 7.0, 7.0,
                       7.0, 7.0, 7.0, 7.0, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([6.9, 6.9, 6.9, 6.9]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel24():
    # WithVarsPerChannelDim4NudgedUp_4Bits_RegularRange
    # scale 1/2, zp: 0.8, nudge 1. nudged range [-0.5, 7.0]
    x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 6.9, 7.0, 7.1, 7.7,
                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.0, 7.0, 7.0,
                       7.0, 7.0, 7.0, 7.0, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
    max_val = np.array([7.1, 7.1, 7.1, 7.1]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel25():
    # WithVarsPerChannelDim4NudgedUp_4Bits_NarrowRange
    # scale 1/2, zp: 1.8, nudge 2. nudged range [-0.5, 6.5]
    x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  5.5, 6.0, 6.4, 6.5, 6.6, 6.7,
                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       5.5, 6.0, 6.5, 6.5, 6.5, 6.5,
                       6.5, 6.5, 6.5, 6.5, 6.5, 6.5]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
    max_val = np.array([6.6, 6.6, 6.6, 6.6]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)
@@ -0,0 +1,373 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
    def __init__(self, num_bits=8, narrow_range=False):
        super(Net, self).__init__()
        self.op = Q.FakeQuantPerChannelGrad(num_bits=num_bits, narrow_range=narrow_range)

    def construct(self, dout, x, minq, maxq):
        return self.op(dout, x, minq, maxq)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad1():
    # WithVarsPerChannelDim1GradientNudgedDown_ZeroMinAndMax
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    min_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    max_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    expect = dout

    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad2():
    # WithVarsPerChannelDim1GradientNudgedDown_RegularRange
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([-0.1, 0.0, 63.75, 63.8]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
    expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)
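
The expectations in these gradient tests are just a pass-through mask over the nudged range; for this test the range is [0.0, 63.75], as derived earlier:

import numpy as np

x = np.array([-0.1, 0.0, 63.75, 63.8], np.float32)
dout = np.random.uniform(-1, 1, size=4).astype(np.float32)
expect = np.where((x >= 0.0) & (x <= 63.75), dout, 0.0)
# -> [0.0, dout[1], dout[2], 0.0]: out-of-range inputs get zero gradient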


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad3():
    # WithVarsPerChannelDim1GradientNudgedDown_NarrowRange
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([-0.1, 0.0, 63.5, 63.6]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
    expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad4():
    # WithVarsPerChannelDim1GradientNudgedUp_RegularRange
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([-0.3, -0.25, 63.5, 63.6]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)
    expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad5():
    # WithVarsPerChannelDim1GradientNudgedUp_NarrowRange
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([-0.3, -0.25, 63.25, 63.3]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)
    expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad6():
    # WithVarsPerChannelDim2GradientNudgedDown_RegularRange
    read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
    x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8]).reshape(3, 2).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65]).astype(np.float32)
    dout = read_dout.flatten()
    expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad7(): | |||
| # WithVarsPerChannelDim2GradientNudgedDown_NarrowRange | |||
| read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32') | |||
| x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]).reshape(3, 2).astype(np.float32) | |||
| min_val = np.array([-0.1, -0.1, -0.1]).astype(np.float32) | |||
| max_val = np.array([63.4, 63.4, 63.4]).astype(np.float32) | |||
| dout = read_dout.flatten() | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], | |||
| dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("=" * 40) | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad8(): | |||
| # WithVarsPerChannelDim2GradientNudgedUp_RegularRange | |||
| read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32') | |||
| x = np.array([-0.3, -0.25, -0.2, 0.0, 63.5, 63.6]).reshape(3, 2).astype(np.float32) | |||
| min_val = np.array([-0.125, -0.125, -0.125]).astype(np.float32) | |||
| max_val = np.array([63.625, 63.625, 63.625]).astype(np.float32) | |||
| dout = read_dout.flatten() | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], | |||
| dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("=" * 40) | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad9(): | |||
| # WithVarsPerChannelDim2GradientNudgedUp_NarrowRange | |||
| read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32') | |||
| x = np.array([-0.3, -0.25, -0.2, 0.0, 63.25, 63.3]).reshape(3, 2).astype(np.float32) | |||
| min_val = np.array([-0.125, -0.125, -0.125]).astype(np.float32) | |||
| max_val = np.array([63.375, 63.375, 63.375]).astype(np.float32) | |||
| dout = read_dout.flatten() | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], | |||
| dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("=" * 40) | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad10(): | |||
| # WithVarsPerChannelDim4GradientNudgedDown_RegularRange | |||
| read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32') | |||
| x = np.array([-0.1, 0.0, 63.75, 63.8, -0.1, 0.0, | |||
| 63.75, 63.8, -0.1, 0.0, 63.75, 63.8, | |||
| -0.1, 0.0, 63.75, 63.8, -0.1, 0.0, | |||
| 63.75, 63.8, -0.1, 0.0, 63.75, 63.8]).reshape(4, 3, 2, 1).astype(np.float32) | |||
| min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32) | |||
| max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32) | |||
| dout = read_dout.flatten() | |||
| expect = np.array([0.0, dout[1], dout[2], 0.0, | |||
| 0.0, dout[5], dout[6], 0.0, | |||
| 0.0, dout[9], dout[10], 0.0, | |||
| 0.0, dout[13], dout[14], 0.0, | |||
| 0.0, dout[17], dout[18], 0.0, | |||
| 0.0, dout[21], dout[22], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("=" * 40) | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad11(): | |||
| # WithVarsPerChannelDim4GradientNudgedDown_NarrowRange | |||
| read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32') | |||
| x = np.array([-0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, | |||
| 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32) | |||
| min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32) | |||
| max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32) | |||
| dout = read_dout.flatten() | |||
| expect = np.array([0.0, dout[1], dout[2], 0.0, | |||
| 0.0, dout[5], dout[6], 0.0, | |||
| 0.0, dout[9], dout[10], 0.0, | |||
| 0.0, dout[13], dout[14], 0.0, | |||
| 0.0, dout[17], dout[18], 0.0, | |||
| 0.0, dout[21], dout[22], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("=" * 40) | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad12(): | |||
| # WithVarsPerChannelDim4GradientNudgedUp_RegularRange | |||
| read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32') | |||
| x = np.array([-0.3, -0.25, 63.5, 63.6, -0.3, -0.25, | |||
| 63.5, 63.6, -0.3, -0.25, 63.5, 63.6, | |||
| -0.3, -0.25, 63.5, 63.6, -0.3, -0.25, | |||
| 63.5, 63.6, -0.3, -0.25, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32) | |||
| min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32) | |||
| max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32) | |||
| dout = read_dout.flatten() | |||
| expect = np.array([0.0, dout[1], dout[2], 0.0, | |||
| 0.0, dout[5], dout[6], 0.0, | |||
| 0.0, dout[9], dout[10], 0.0, | |||
| 0.0, dout[13], dout[14], 0.0, | |||
| 0.0, dout[17], dout[18], 0.0, | |||
| 0.0, dout[21], dout[22], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("=" * 40) | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad13(): | |||
| # WithVarsPerChannelDim4GradientNudgedUp_NarrowRange | |||
| read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32') | |||
| x = np.array([-0.3, -0.25, 63.25, 63.3, -0.3, -0.25, | |||
| 63.25, 63.3, -0.3, -0.25, 63.25, 63.3, | |||
| -0.3, -0.25, 63.25, 63.3, -0.3, -0.25, | |||
| 63.25, 63.3, -0.3, -0.25, 63.25, 63.3]).reshape(4, 3, 2, 1).astype(np.float32) | |||
| min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32) | |||
| max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32) | |||
| dout = read_dout.flatten() | |||
| expect = np.array([0.0, dout[1], dout[2], 0.0, | |||
| 0.0, dout[5], dout[6], 0.0, | |||
| 0.0, dout[9], dout[10], 0.0, | |||
| 0.0, dout[13], dout[14], 0.0, | |||
| 0.0, dout[17], dout[18], 0.0, | |||
| 0.0, dout[21], dout[22], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("=" * 40) | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
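| # A minimal NumPy sketch of the rule behind every expect array above: the | |||
| # straight-through gradient passes dout where x lies inside the per-channel | |||
| # nudged range (channel = axis 0) and is 0 elsewhere. Illustrative reference | |||
| # only, not the GPU kernel; _fake_quant_perchannel_grad_ref is a hypothetical | |||
| # helper, and nudge_min/nudge_max are assumed precomputed from min_val/max_val | |||
| # by the nudging sketched further below. | |||
| def _fake_quant_perchannel_grad_ref(dout, x, nudge_min, nudge_max): | |||
|     shape = (-1,) + (1,) * (x.ndim - 1)  # broadcast per-channel bounds over x | |||
|     in_range = (x >= nudge_min.reshape(shape)) & (x <= nudge_max.reshape(shape)) | |||
|     return np.where(in_range, dout, 0.0).astype(np.float32) | |||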
| @@ -0,0 +1,386 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| from mindspore.common.tensor import Tensor | |||
| import mindspore.nn as nn | |||
| from mindspore.ops.operations import _quant_ops as Q | |||
| context.set_context(device_target='GPU', device_id=0) | |||
| class Net(nn.Cell): | |||
| def __init__(self, | |||
| num_bits=8, | |||
| quant_delay=0, | |||
| symmetric=False, | |||
| narrow_range=False, | |||
| training=True): | |||
| super(Net, self).__init__() | |||
| self.fake_quant = Q.FakeQuantPerLayer(num_bits=num_bits, | |||
| quant_delay=quant_delay, | |||
| symmetric=symmetric, | |||
| narrow_range=narrow_range, | |||
| training=training) | |||
| def construct(self, x, minq, maxq): | |||
| return self.fake_quant(x, minq, maxq) | |||
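| # A minimal NumPy sketch of where the NudgedDown/NudgedUp test names come | |||
| # from, assuming the standard TF-style nudging; _nudge_min_max_ref is a | |||
| # hypothetical helper, not the kernel. The float zero point is rounded to the | |||
| # integer grid, which shifts both bounds of the quantization range. | |||
| def _nudge_min_max_ref(min_val, max_val, num_bits, narrow_range): | |||
|     quant_min = 1 if narrow_range else 0 | |||
|     quant_max = (1 << num_bits) - 1 | |||
|     scale = (max_val - min_val) / (quant_max - quant_min) | |||
|     zp_from_min = quant_min - min_val / scale | |||
|     # round to the nearest integer (half up, which matches these test ranges) | |||
|     nudged_zp = np.clip(np.floor(zp_from_min + 0.5), quant_min, quant_max) | |||
|     return (quant_min - nudged_zp) * scale, (quant_max - nudged_zp) * scale | |||
| # e.g. min=-0.1, max=63.65, 8 bits, regular range: scale = 63.75/255 = 0.25, | |||
| # zp_from_min = 0.4 -> 0, nudged range [0.0, 63.75] (NudgedDown, test_fake_quant4); | |||
| # min=-0.125, max=63.625: zp_from_min = 0.5 -> 1, range [-0.25, 63.5] (NudgedUp). | |||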
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant1(): | |||
| # Zero min and max (num_bits=8, narrow_range=False): every output is 0. | |||
| x = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([0]).reshape(1).astype(np.float32) | |||
| max_val = np.array([0]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant2(): | |||
| # WithVarsNoNudging_RegularRange: min=-10.0, max=53.75 give an exact 0.25 grid. | |||
| x = np.array([-10.1, -10.0, -9.9, -9.75, 53.75, 53.8]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-10.0]).reshape(1).astype(np.float32) | |||
| max_val = np.array([53.75]).reshape(1).astype(np.float32) | |||
| expect = np.array([-10.0, -10.0, -10.0, -9.75, 53.75, 53.75]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant3(): | |||
| # WithVarsNoNudging_NarrowRange | |||
| x = np.array([-10.1, -10.0, -9.90, -9.75, 53.5, 53.6]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-10.0]).reshape(1).astype(np.float32) | |||
| max_val = np.array([53.5]).reshape(1).astype(np.float32) | |||
| expect = np.array([-10.0, -10.0, -10.0, -9.75, 53.5, 53.5]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant4(): | |||
| # WithVarsNudgedDown_RegularRange | |||
| x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-0.1]).reshape(1).astype(np.float32) | |||
| max_val = np.array([63.65]).reshape(1).astype(np.float32) | |||
| expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant5(): | |||
| # WithVarsNudgedDown_NarrowRange | |||
| x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-0.1]).reshape(1).astype(np.float32) | |||
| max_val = np.array([63.4]).reshape(1).astype(np.float32) | |||
| expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.5, 63.5]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant6(): | |||
| # WithVarsNudgedUp_RegularRange | |||
| x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-0.125]).reshape(1).astype(np.float32) | |||
| max_val = np.array([63.625]).reshape(1).astype(np.float32) | |||
| expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant7(): | |||
| # WithVarsNudgedUp_NarrowRange | |||
| x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-0.125]).reshape(1).astype(np.float32) | |||
| max_val = np.array([63.375]).reshape(1).astype(np.float32) | |||
| expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.25, 63.25]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant8(): | |||
| # WithVarsNudgedZeroIs255_RegularRange | |||
| x = np.array([-63.80, -63.75, -63.70, -63.5, 0.0, 0.1]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-63.65]).reshape(1).astype(np.float32) | |||
| max_val = np.array([0.1]).reshape(1).astype(np.float32) | |||
| expect = np.array([-63.75, -63.75, -63.75, -63.5, 0.0, 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant9(): | |||
| # WithVarsNudgedZeroIs255_NarrowRange | |||
| x = np.array([-63.6, -63.5, -63.4, -63.25, 0.0, 0.1]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-63.4]).reshape(1).astype(np.float32) | |||
| max_val = np.array([0.1]).reshape(1).astype(np.float32) | |||
| expect = np.array([-63.5, -63.5, -63.5, -63.25, 0.0, 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant10(): | |||
| # WithVarsNoNudging_4Bits_RegularRange | |||
| x = np.array([-6.1, -6.0, -5.9, -5.5, 1.5, 1.6]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-6.0]).reshape(1).astype(np.float32) | |||
| max_val = np.array([1.5]).reshape(1).astype(np.float32) | |||
| expect = np.array([-6.0, -6.0, -6.0, -5.5, 1.5, 1.5]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant11(): | |||
| # WithVarsNoNudging_4Bits_NarrowRange | |||
| x = np.array([-6.1, -6.0, -5.9, -5.5, 1.0, 1.1]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-6.0]).reshape(1).astype(np.float32) | |||
| max_val = np.array([1.0]).reshape(1).astype(np.float32) | |||
| expect = np.array([-6.0, -6.0, -6.0, -5.5, 1.0, 1.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=True) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant12(): | |||
| # WithVarsNudgedDown_4Bits_RegularRange | |||
| x = np.array([-0.1, 0.0, 0.1, 0.5, 7.5, 7.6]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-0.1]).reshape(1).astype(np.float32) | |||
| max_val = np.array([7.4]).reshape(1).astype(np.float32) | |||
| expect = np.array([-0.0, 0.0, 0.0, 0.5, 7.5, 7.5]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant13(): | |||
| # WithVarsNudgedDown_4Bits_NarrowRange | |||
| x = np.array([-0.1, 0.0, 0.1, 0.5, 7.0, 7.1]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-0.1]).reshape(1).astype(np.float32) | |||
| max_val = np.array([6.9]).reshape(1).astype(np.float32) | |||
| expect = np.array([-0.0, 0.0, 0.0, 0.5, 7.0, 7.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=True) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant14(): | |||
| # WithVarsNudgedUp_4Bits_RegularRange | |||
| x = np.array([-0.6, -0.5, -0.24, 0.0, 7.0, 7.1]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-0.4]).reshape(1).astype(np.float32) | |||
| max_val = np.array([7.1]).reshape(1).astype(np.float32) | |||
| expect = np.array([-0.5, -0.5, -0.00, 0.0, 7.0, 7.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant15(): | |||
| # WithVarsNudgedUp_4Bits_NarrowRange | |||
| x = np.array([-0.6, -0.5, -0.24, 0.0, 6.5, 6.6]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-0.4]).reshape(1).astype(np.float32) | |||
| max_val = np.array([6.6]).reshape(1).astype(np.float32) | |||
| expect = np.array([-0.5, -0.5, -0.00, 0.0, 6.5, 6.5]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=True) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant16(): | |||
| # WithVarsNudgedZero15_4Bits_RegularRange | |||
| x = np.array([-7.6, -7.5, -7.4, -7.2, 0.0, 0.1]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-7.3]).reshape(1).astype(np.float32) | |||
| max_val = np.array([0.2]).reshape(1).astype(np.float32) | |||
| expect = np.array([-7.5, -7.5, -7.5, -7.0, 0.0, 0.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=False) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant17(): | |||
| # WithVarsNudgedZero15_4Bits_NarrowRange | |||
| x = np.array([-7.1, -7.0, -6.9, -6.5, 0.0, 0.1]).reshape(2, 3).astype(np.float32) | |||
| min_val = np.array([-6.8]).reshape(1).astype(np.float32) | |||
| max_val = np.array([0.2]).reshape(1).astype(np.float32) | |||
| expect = np.array([-7.0, -7.0, -7.0, -6.5, 0.0, 0.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=True) | |||
| output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @@ -0,0 +1,221 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| from mindspore import Tensor | |||
| import mindspore.nn as nn | |||
| import mindspore.context as context | |||
| from mindspore.ops.operations import _quant_ops as Q | |||
| context.set_context(device_target='GPU', device_id=0) | |||
| class Net(nn.Cell): | |||
| def __init__(self, num_bits=8, narrow_range=False): | |||
| super(Net, self).__init__() | |||
| self.op = Q.FakeQuantPerLayerGrad(num_bits=num_bits, narrow_range=narrow_range) | |||
| def construct(self, dout, x, minq, maxq): | |||
| return self.op(dout, x, minq, maxq) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad1(): | |||
| # WithArgsGradient_RegularRange | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).astype(np.float32) | |||
| min_val = np.array([-0.125]).reshape(1).astype(np.float32) | |||
| max_val = np.array([63.625]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad2(): | |||
| # WithArgsGradient_NarrowRange | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).astype(np.float32) | |||
| min_val = np.array([-0.125]).reshape(1).astype(np.float32) | |||
| max_val = np.array([63.375]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad3(): | |||
| # WithArgsGradient_4Bits_RegularRange | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1]).astype(np.float32) | |||
| min_val = np.array([-0.4]).reshape(1).astype(np.float32) | |||
| max_val = np.array([7.1]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=False) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad4(): | |||
| # WithArgsGradient_4Bits_NarrowRange | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6]).astype(np.float32) | |||
| min_val = np.array([-0.4]).reshape(1).astype(np.float32) | |||
| max_val = np.array([6.6]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=True) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad5(): | |||
| # FakeQuantWithMinMaxVarsGradient | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).astype(np.float32) | |||
| min_val = np.array([0.0]).reshape(1).astype(np.float32) | |||
| max_val = np.array([0.0]).reshape(1).astype(np.float32) | |||
| expect = dout | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad6(): | |||
| # WithVarsGradient_RegularRange | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).astype(np.float32) | |||
| min_val = np.array([-0.125]).reshape(1).astype(np.float32) | |||
| max_val = np.array([63.625]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=False) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad7(): | |||
| # WithVarsGradient_NarrowRange | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).astype(np.float32) | |||
| min_val = np.array([-0.125]).reshape(1).astype(np.float32) | |||
| max_val = np.array([63.375]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=8, narrow_range=True) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad8(): | |||
| # WithVarsGradient_4Bits_RegularRange | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1]).astype(np.float32) | |||
| min_val = np.array([-0.4]).reshape(1).astype(np.float32) | |||
| max_val = np.array([7.1]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=False) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_fake_quant_grad9(): | |||
| # WithVarsGradient_4Bits_NarrowRange | |||
| dout = np.random.uniform(-1, 1, size=[6]).astype('float32') | |||
| x = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6]).astype(np.float32) | |||
| min_val = np.array([-0.4]).reshape(1).astype(np.float32) | |||
| max_val = np.array([6.6]).reshape(1).astype(np.float32) | |||
| expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32) | |||
| net = Net(num_bits=4, narrow_range=True) | |||
| output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) | |||
| error = np.ones(shape=expect.shape) * 1.0e-5 | |||
| diff = output.asnumpy().flatten() - expect | |||
| print("output: ", output) | |||
| print("expect: ", expect) | |||
| assert np.all(np.abs(diff) < error) | |||