add fake quant test case for gpu

tags/v0.6.0-beta
chenzomi 5 years ago
commit 8873f9dc7e
13 changed files with 1703 additions and 123 deletions
  1. mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu (+24, -23)
  2. mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh (+13, -14)
  3. mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu (+25, -49)
  4. mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh (+10, -11)
  5. mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc (+3, -3)
  6. mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc (+3, -3)
  7. mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc (+8, -8)
  8. mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc (+4, -4)
  9. mindspore/train/quant/quant.py (+8, -8)
 10. tests/st/ops/gpu/test_fake_quant_perchannel.py (+625, -0)
 11. tests/st/ops/gpu/test_fake_quant_perchannel_grad.py (+373, -0)
 12. tests/st/ops/gpu/test_fake_quant_perlayer.py (+386, -0)
 13. tests/st/ops/gpu/test_fake_quant_perlayer_grad.py (+221, -0)

mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu (+24, -23)

@@ -20,7 +20,6 @@
 #include <thrust/reduce.h>
 #include <thrust/pair.h>
 #include "fake_quant_perchannel_impl.cuh"
-#include "device/gpu/cuda_common.h"
 /**
  * Find the nudge min, max and scale value as output.
@@ -34,13 +33,17 @@
  * @param channel_num
  * @return
  */
-__global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input_max, const float quant_min,
-                                      const float quant_max, float *nudge_min, float *nudge_max, float *scale,
-                                      int channel_num) {
+__global__ void NudgeMinMaxPerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                                      float *nudge_min, float *nudge_max, float *scale, int channel_num,
+                                      const bool symmetric) {
   float zp_from_min = 0.f;
   float nudge_zp = 0.f;
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) {
+    if (symmetric) {
+      input_max[i] = abs(input_min[0]) < input_max[i] ? input_max[i] : -input_min[i];
+      input_min[i] = abs(input_min[i]) < input_max[i] ? -input_max[i] : input_min[i];
+    }
     if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) {
       scale[i] = 0.f;
       zp_from_min = 0.f;
@@ -62,11 +65,11 @@ __global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input
   }
 }
-void CalNudgePerChannel(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-                        float *nudge_min, float *nudge_max, float *scale, const int channel_num,
+void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
                         cudaStream_t cuda_stream) {
   NudgeMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
-    input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num);
+    input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num, symmetric);
 }
 /**
@@ -80,9 +83,8 @@ void CalNudgePerChannel(const float *input_min, const float *input_max, const fl
  * @param scale - array
  * @return
  */
-__global__ void FakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
-                                       const float *nudge_min, const float *nudge_max, const float *scale,
-                                       bool symmetric) {
+__global__ void FakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size,
+                                    const float *nudge_min, const float *nudge_max, const float *scale) {
   float input_x = 0.f;
   int nudge_input = 0;
   int channel_idx = 0;
@@ -106,16 +108,15 @@ __global__ void FakeQuantizePerChannel(const float *input, float *output, const
   }
 }
-void CalFakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
-                               const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric,
-                               cudaStream_t cuda_stream) {
-  FakeQuantizePerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(
-    input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
+void CalFakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size,
+                            const float *nudge_min, const float *nudge_max, const float *scale,
+                            cudaStream_t cuda_stream) {
+  FakeQuantPerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(input, output, total_size, channel_size,
+                                                                               nudge_min, nudge_max, scale);
 }
-__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
-                                           const int total_size, const int channel_size, const float *nudge_min,
-                                           const float *nudge_max) {
+__global__ void FakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_size,
+                                        const int channel_size, const float *nudge_min, const float *nudge_max) {
   int channel_idx = 0;
   int per_channel_num = total_size / channel_size;
@@ -129,9 +130,9 @@ __global__ void FakeQuantizePerChannelGrad(const float *input, const float *grad
   }
 }
-void CalFakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
-                                   const int channel_num, const float *nudge_min, const float *nudge_max,
-                                   cudaStream_t cuda_stream) {
-  FakeQuantizePerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
-    input, gradient, output, total_num, channel_num, nudge_min, nudge_max);
+void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
+                                const int channel_num, const float *nudge_min, const float *nudge_max,
+                                cudaStream_t cuda_stream) {
+  FakeQuantPerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, total_num,
+                                                                                    channel_num, nudge_min, nudge_max);
 }
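The substantive change here is the new `symmetric` flag threaded into `CalNudgePerChannel`: when set, each channel's (min, max) window is mirrored around zero before the scale and nudged range are derived, and the `symmetric` argument disappears from the quantize launcher, which only consumes already-nudged values. A minimal NumPy-free sketch of the per-channel nudge math, assuming the part of the kernel elided from this hunk follows the usual TF-style nudging (the test comments later in this diff quote exactly these scale and zero-point values):

def nudge_min_max(in_min, in_max, quant_min, quant_max, symmetric=False):
    # Hypothetical helper mirroring NudgeMinMaxPerChannel for one channel.
    if symmetric:
        # New in this commit: force the range symmetric around zero first.
        in_min, in_max = min(in_min, -in_max), max(in_max, -in_min)
    if (quant_max - quant_min) == 0 or (in_max - in_min) == 0:
        # Degenerate range: the kernel zeroes scale and both nudged bounds.
        return 0.0, 0.0, 0.0
    # Assumed TF-style nudging for the kernel body elided from this hunk:
    scale = (in_max - in_min) / (quant_max - quant_min)
    zp_from_min = quant_min - in_min / scale
    nudged_zp = min(max(round(zp_from_min), quant_min), quant_max)
    nudge_min = (quant_min - nudged_zp) * scale
    nudge_max = (quant_max - nudged_zp) * scale
    return nudge_min, nudge_max, scale

# e.g. min=-0.1, max=63.65 with the 8-bit regular range (0..255):
# scale 0.25, zp_from_min 0.4 -> nudged zp 0, nudged range [0.0, 63.75]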

mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh (+13, -14)

@@ -14,22 +14,21 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
-void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
-                        float* nudge_min, float* nudge_max, float* scale, const int channel_num,
-                        cudaStream_t cuda_stream);
+#include "device/gpu/cuda_common.h"
-void CalFakeQuantizePerChannel(const float* input, float* output, const int total_num, const int channel_num,
-                               const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
-                               cudaStream_t cuda_stream);
+void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
+                        cudaStream_t cuda_stream);
-void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_num, const int channel_num,
-                         const float ema_decay, const bool ema, cudaStream_t cuda_stream);
+void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num,
+                            const float *nudge_min, const float *nudge_max, const float *scale,
+                            cudaStream_t cuda_stream);
-void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
-                                   const int channel_num, const float* nudge_min, const float* nudge_max,
-                                   cudaStream_t cuda_stream);
+void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
+                                const int channel_num, const float *nudge_min, const float *nudge_max,
+                                cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_

mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu (+25, -49)

@@ -17,11 +17,10 @@
 #include <thrust/extrema.h>
 #include <thrust/device_vector.h>
 #include <thrust/pair.h>
-#include "device/gpu/cuda_common.h"
 #include "fake_quant_perlayer_impl.cuh"
-__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
-                             const float *nudge_max, const float *scale) {
+__global__ void FakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                                  const float *nudge_max, const float *scale) {
   float input_x = 0.f;
   int nudge_input = 0;
@@ -43,8 +42,8 @@ __global__ void FakeQuantize(const float *input, float *output, const int size,
   return;
 }
-__global__ void FakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                                 const float *nudge_min, const float *nudge_max) {
+__global__ void FakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                                      const float *nudge_min, const float *nudge_max) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
     if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) {
       output[i] = 0;
@@ -55,12 +54,18 @@ __global__ void FakeQuantizeGrad(const float *input, const float *gradient, floa
   return;
 }
-__global__ void NudgeMinMax(const float *input_min, const float *input_max, const float quant_min,
-                            const float quant_max, float *nudge_min, float *nudge_max, float *scale) {
+__global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                                    float *nudge_min, float *nudge_max, float *scale, const bool symmetric) {
   float zp_from_min = 0.f;
   scale[0] = 0.f;
   nudge_max[0] = 0.f;
   nudge_min[0] = 0.f;
+  if (symmetric) {
+    input_max[0] = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0];
+    input_min[0] = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0];
+  }
   if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) {
     scale[0] = 0.f;
     zp_from_min = 0.f;
@@ -83,53 +88,24 @@ __global__ void NudgeMinMax(const float *input_min, const float *input_max, cons
   return;
 }
-__global__ void UpdateInputMinMaxWithEMA(float *input_min, float *input_max, const float min, const float max,
-                                         const float decay) {
-  input_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
-  input_min[0] = input_min[0] > 0 ? 0 : input_min[0];
-  input_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
-  input_max[0] = input_max[0] < 0 ? 0 : input_max[0];
-  return;
-}
-__global__ void UpdateInputMinMax(float *input_min, float *input_max, const float min, const float max) {
-  input_min[0] = min > 0 ? 0 : min;
-  input_max[0] = max < 0 ? 0 : max;
-}
-void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
-                     const float *scale, bool symmetric, cudaStream_t cuda_stream) {
-  FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale);
-  return;
-}
-void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                         const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
-  FakeQuantizeGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
-                                                                      nudge_max);
+void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                          const float *nudge_max, const float *scale, cudaStream_t cuda_stream) {
+  FakeQuantPerLayer<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max,
+                                                                       scale);
   return;
 }
-void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-              float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream) {
-  NudgeMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
+void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                              const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
+  FakeQuantPerLayerGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
+                                                                           nudge_max);
   return;
 }
-void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
-               cudaStream_t cuda_stream) {
-  float minel = 0.f;
-  float maxel = 0.f;
-  auto policy = thrust::cuda::par.on(cuda_stream);
-  thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
-  tuple = thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
-  minel = tuple.first[0];
-  maxel = tuple.second[0];
-  if (ema) {
-    UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay);
-  } else {
-    UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel);
-  }
+void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                      float *nudge_min, float *nudge_max, float *scale, const bool symmetric,
+                      cudaStream_t cuda_stream) {
+  NudgeMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale,
+                                                symmetric);
   return;
 }
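The `FakeQuantPerLayer` and `FakeQuantPerLayerGrad` kernel bodies are untouched by this commit and mostly elided above. For orientation, a hedged NumPy sketch of the standard fake-quant forward step such kernels perform, clamping into the nudged range and snapping to the scale grid (an assumption consistent with the expected values in the tests below, not the shipped kernel itself):

import numpy as np

def fake_quant(x, nudge_min, nudge_max, scale):
    # Hypothetical helper: clamp into the nudged range, then snap each
    # value to the quantization grid. floor(x + 0.5) emulates the
    # round-half-up a GPU kernel would typically use.
    if scale == 0:
        return np.zeros_like(x)  # degenerate range, as in the zero-min/max test
    clamped = np.clip(x, nudge_min, nudge_max)
    return np.floor((clamped - nudge_min) / scale + 0.5) * scale + nudge_min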

mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh (+10, -11)

@@ -14,19 +14,18 @@
  * limitations under the License.
  */
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
-void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
-                     const float *scale, bool symmetric, cudaStream_t cuda_stream);
+#include "device/gpu/cuda_common.h"
-void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
-                         const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);
+void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
+                      float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream);
-void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
-              float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream);
+void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
+                          const float *nudge_max, const float *scale, cudaStream_t cuda_stream);
-void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
-               cudaStream_t cuda_stream);
+void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
+                              const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_

mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc (+3, -3)

@@ -102,9 +102,9 @@ void FakeQuantPerChannelGpuKernel::InitSizeLists() {
 void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
                                                    float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) {
   CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
-                     reinterpret_cast<cudaStream_t>(stream_ptr));
-  CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
-                            symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+                     symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+  CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
 }
 bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,

mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc (+3, -3)

@@ -119,9 +119,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
   int total_size = input_size_ / sizeof(float);
   if (global_step_ >= quant_delay_) {
     CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
-                       reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
-                                  reinterpret_cast<cudaStream_t>(stream_ptr));
+                       symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
+                               reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
     CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),

mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc (+8, -8)

@@ -117,10 +117,10 @@ bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, c
     // control flow for quant_delay
     if (global_step_ >= quant_delay_) {
       // real launch
-      CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-               reinterpret_cast<cudaStream_t>(stream_ptr));
-      CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
-                      reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+      CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
+                           reinterpret_cast<cudaStream_t>(stream_ptr));
     } else {
       CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
@@ -129,10 +129,10 @@ bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, c
     global_step_++;
   } else {
     // real launch
-    CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_,
-                    reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
   }
   return true;
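The quant_delay control flow above makes the op an identity copy during training until global_step_ reaches quant_delay_, after which nudging plus fake quantization run for real. A hypothetical sketch of that training-mode branch, reusing the nudge_min_max and fake_quant helpers sketched earlier (0..255 assumes num_bits=8 with the regular, non-narrow range; the non-training else branch always launches):

import numpy as np

def launch(x, in_min, in_max, global_step, quant_delay, symmetric=False):
    # Sketch of the training-mode quant_delay gate in Launch().
    if global_step >= quant_delay:
        nmin, nmax, scale = nudge_min_max(in_min, in_max, 0, 255, symmetric)
        return fake_quant(x, nmin, nmax, scale)
    # Before the delay expires, the kernel memcpys input straight to output.
    return np.array(x, copy=True)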


mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc (+4, -4)

@@ -115,10 +115,10 @@ bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &input
   }
   if (global_step_ >= quant_delay_) {
-    CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
-                        reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
     CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),


mindspore/train/quant/quant.py (+8, -8)

@@ -150,7 +150,7 @@ class ConvertToQuantNetwork:
         prefix = name
         add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                               num_bits=self.act_bits,
-                                              quant_delay=self.act_delay,
+                                              quant_delay=self.act_qdelay,
                                               per_channel=self.act_channel,
                                               symmetric=self.act_symmetric,
                                               narrow_range=self.act_range)
@@ -408,19 +408,19 @@

     Args:
         network (Cell): Obtain a pipeline through network for saving graph summary.
-        quant_delay (int or tuple): Number of steps after which weights and activations are quantized during
-            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
         bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
         freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0.
-        num_bits (int or tuple): Number of bits to use for quantizing weights and activations. The first
+        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
+            eval. The first element represent weights and second element represent data flow. Default: (0, 0)
+        num_bits (int, list or tuple): Number of bits to use for quantizing weights and activations. The first
             element represent weights and second element represent data flow. Default: (8, 8)
-        per_channel (int or tuple): Quantization granularity based on layer or on channel. If `True`
+        per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
            then base on per channel otherwise base on per layer. The first element represent weights
            and second element represent data flow. Default: (False, False)
-        symmetric (int or tuple): Quantization algorithm use symmetric or not. If `True` then base on
-            symmetric otherwise base on assymmetric. The first element represent weights and second
+        symmetric (bool, list or tuple): Quantization algorithm use symmetric or not. If `True` then base on
+            symmetric otherwise base on asymmetric. The first element represent weights and second
            element represent data flow. Default: (False, False)
-        narrow_range (int or tuple): Quantization algorithm use narrow range or not. If `True` then base
+        narrow_range (bool, list or tuple): Quantization algorithm use narrow range or not. If `True` then base
            on narrow range otherwise base on off narrow range. The first element represent weights and
            second element represent data flow. Default: (False, False)
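Per the corrected docstring, the pair form of each argument configures weights first and activations ("data flow") second. A hypothetical call, only to illustrate the argument format (`SimpleNet` is a stand-in for any trainable Cell, the import path is assumed from this file's location, and whether a given topology converts depends on the cells the pass supports):

import mindspore.nn as nn
from mindspore.train.quant import quant  # assumed import path for quant.py

class SimpleNet(nn.Cell):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.dense = nn.Dense(16, 10)

    def construct(self, x):
        return self.dense(x)

# First element of each pair configures weights, second configures data flow.
quant_net = quant.convert_quant_network(SimpleNet(),
                                        quant_delay=(0, 900),
                                        num_bits=(8, 8),
                                        per_channel=(True, False),
                                        symmetric=(True, False),
                                        narrow_range=(False, False))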




tests/st/ops/gpu/test_fake_quant_perchannel.py (+625, -0)

@@ -0,0 +1,625 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore import nn
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
    def __init__(self, num_bits=8, symmetric=False, narrow_range=False, channel_axis=1):
        super(Net, self).__init__()
        self.op = Q.FakeQuantPerChannel(num_bits=num_bits,
                                        symmetric=symmetric,
                                        narrow_range=narrow_range,
                                        channel_axis=channel_axis)

    def construct(self, x, minq, maxq):
        return self.op(x, minq, maxq)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel1():
    # WithVarsPerChannel_ZeroMinAndMax
    x = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    min_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    max_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel2():
    # WithVarsPerChannelDim1NudgedDown_RegularRange
    # scale 1/4, zp 0.4, nudge 0. nudged ranges [0.0, 63.75]
    x = np.array([-0.1, 0.0, 63.75, 63.8]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
    expect = np.array([0.0, 0.0, 63.75, 63.75]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel3():
    # WithVarsPerChannelDim1NudgedDown_NarrowRange
    # scale 1/4, zp 1.4, nudge 1. nudged ranges[0.0, 63.5]
    x = np.array([-0.1, 0.0, 63.5, 63.6]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
    expect = np.array([0.0, 0.0, 63.5, 63.5]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel4():
    # WithVarsPerChannelDim1NudgedUp_RegularRange
    # [-0.125, 63.625]
    # scale 1/4, zp: 0.5, nudge 0. nudged range [-0.25, 63.5]
    x = np.array([-0.26, -0.25, -0.24, 63.6]).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 63.5]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel5():
    # WithVarsPerChannelDim1NudgedUp_NarrowRange
    # scale 1/4, zp: 1.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.26, -0.25, -0.24, 63.3]).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 63.25]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel6():
    # WithVarsPerChannelDim2NudgedDown_RegularRange
    # scale 1/4, zp: 0.4, nudge 0. nudged range [-0.25, 63.75]
    x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.80]
                 ).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65]).reshape(3).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel7():
    # WithVarsPerChannelDim2NudgedDown_NarrowRange
    # scale 1/4, zp: 1.4, nudge 1. nudged range [-0.25, 63.5]
    x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]
                 ).reshape(2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.25, 63.5, 63.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4]).reshape(3).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel8():
    # WithVarsPerChannelDim2NudgedUp_RegularRange
    # scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5]
    x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]
                 ).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5]
                      ).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125]).reshape(3).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625]).reshape(3).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel9():
    # WithVarsPerChannelDim2NudgedUp_NarrowRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]
                 ).reshape(2, 3).astype(np.float32)
    expect = np.array(
        [-0.25, -0.25, -0.25, 0.0, 63.25, 63.25]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125]).reshape(3).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375]).reshape(3).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel10():
    # WithVarsPerChannelDim4NudgedDown_RegularRange
    # scale 1/4, zp: 0.4, nudge 0. nudged range [-0.25, 63.25]
    x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                  1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                  63.0, 63.25, 63.5, 63.7, 63.75, 63.8,
                  63.9, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75,
                       1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                       63.0, 63.25, 63.5, 63.75, 63.75, 63.75,
                       63.75, 63.75, 63.75, 63.75, 63.75, 63.75]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65, 63.65]
                       ).reshape(4).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect

    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel11():
    # WithVarsPerChannelDim4NudgedDown_NarrowRange
    # scale 1/4, zp: 1.4, nudge 1. nudged range [0.0, 63.25]
    x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                  1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                  63.0, 63.25, 63.3, 63.4, 63.5, 63.6,
                  63.7, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75,
                       1.0, 1.25, 1.5, 1.75, 2.0, 2.25,
                       63.0, 63.25, 63.25, 63.5, 63.5, 63.5,
                       63.5, 63.5, 63.5, 63.5, 63.5, 63.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4, 63.4]).reshape(4).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel12():
    # WithVarsPerChannelDim4NudgedUp_RegularRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5,
                  0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                  63.0, 63.25, 63.4, 63.5, 63.6, 63.7,
                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5,
                       0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                       63.0, 63.25, 63.5, 63.5, 63.5, 63.5,
                       63.5, 63.5, 63.5, 63.5, 63.5, 63.5]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]
                       ).reshape(4).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625, 63.625]
                       ).reshape(4).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel13():
    # WithVarsPerChannelDim4NudgedUp_NarrowRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5,
                  0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                  63.0, 63.2, 63.25, 63.3, 63.4, 63.5,
                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5,
                       0.75, 1.0, 1.25, 1.5, 1.75, 2.0,
                       63.0, 63.25, 63.25, 63.25, 63.25, 63.25,
                       63.25, 63.25, 63.25, 63.25, 63.25, 63.25]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]
                       ).reshape(4).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375, 63.375]
                       ).reshape(4).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel14():
    # WithVarsPerChannelDim1NudgedDown_4Bits_RegularRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.1, 0.0, 7.5, 7.6]).reshape(4).astype(np.float32)
    expect = np.array([0.0, 0.0, 7.5, 7.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([7.4, 7.4, 7.4, 7.4]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel15():
    # WithVarsPerChannelDim1NudgedDown_4Bits_NarrowRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.1, 0.0, 7.0, 7.1]).reshape(4).astype(np.float32)
    expect = np.array([0.0, 0.0, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([6.9, 6.9, 6.9, 6.9]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel16():
    # WithVarsPerChannelDim1NudgedUp_4Bits_RegularRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.6, -0.5, 7.0, 7.1]).reshape(4).astype(np.float32)
    expect = np.array([-0.5, -0.5, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
    max_val = np.array([7.1, 7.1, 7.1, 7.1]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel17():
    # WithVarsPerChannelDim1NudgedUp_4Bits_NarrowRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.6, -0.5, 6.5, 6.6]).reshape(4).astype(np.float32)
    expect = np.array([-0.5, -0.5, 6.5, 6.5]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
    max_val = np.array([6.6, 6.6, 6.6, 6.6]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=0)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel18():
    # WithVarsPerChannelDim2NudgedDown_4Bits_RegularRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.1, 0.0, 0.1, 0.5, 7.5, 7.6]
                 ).reshape(2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 7.5, 7.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
    max_val = np.array([7.4, 7.4, 7.4]).reshape(3).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel19():
    # WithVarsPerChannelDim2NudgedDown_4Bits_NarrowRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.1, 0.0, 0.1, 0.5, 7.0, 7.1]
                 ).reshape(2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32)
    max_val = np.array([6.9, 6.9, 6.9]).reshape(3).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel20():
    # WithVarsPerChannelDim2NudgedUp_4Bits_RegularRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.51, -0.5, -0.24, 0.0, 7.0, 7.1]
                 ).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.5, -0.5, 0.0, 0.0, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4]).reshape(3).astype(np.float32)
    max_val = np.array([7.1, 7.1, 7.1]).reshape(3).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel21():
    # WithVarsPerChannelDim2NudgedUp_4Bits_NarrowRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.6, -0.5, -0.24, 0.0, 6.5, 6.6]
                 ).reshape(2, 3).astype(np.float32)
    expect = np.array([-0.5, -0.5, 0.0, 0.0, 6.5, 6.5]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4]).reshape(3).astype(np.float32)
    max_val = np.array([6.6, 6.6, 6.6]).reshape(3).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel22():
    # WithVarsPerChannelDim4NudgedDown_4Bits_RegularRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 7.0, 7.4, 7.5, 7.7,
                  7.8, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.5, 7.5, 7.5,
                       7.5, 7.5, 7.5, 7.5, 7.5, 7.5]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([7.4, 7.4, 7.4, 7.4]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel23():
    # WithVarsPerChannelDim4NudgedDown_4Bits_NarrowRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 6.8, 6.9, 7.0, 7.1,
                  7.2, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.0, 7.0, 7.0,
                       7.0, 7.0, 7.0, 7.0, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32)
    max_val = np.array([6.9, 6.9, 6.9, 6.9]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel24():
    # WithVarsPerChannelDim4NudgedUp_4Bits_RegularRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  6.0, 6.5, 6.9, 7.0, 7.1, 7.7,
                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       6.0, 6.5, 7.0, 7.0, 7.0, 7.0,
                       7.0, 7.0, 7.0, 7.0, 7.0, 7.0]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
    max_val = np.array([7.1, 7.1, 7.1, 7.1]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=False, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_perchannel25():
    # WithVarsPerChannelDim4NudgedUp_4Bits_NarrowRange
    # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25]
    x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0,
                  1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                  5.5, 6.0, 6.4, 6.5, 6.6, 6.7,
                  100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32)
    expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0,
                       1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
                       5.5, 6.0, 6.5, 6.5, 6.5, 6.5,
                       6.5, 6.5, 6.5, 6.5, 6.5, 6.5]).astype(np.float32)
    min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32)
    max_val = np.array([6.6, 6.6, 6.6, 6.6]).reshape(4).astype(np.float32)

    net = Net(num_bits=4, narrow_range=True, channel_axis=1)
    output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)
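As a worked example, the expected values in test_fake_quant_perchannel2 fall out of the sketches given earlier: min -0.1 and max 63.65 nudge to [0.0, 63.75] with scale 0.25, so -0.1 and 0.0 clamp to 0.0 while 63.75 and 63.8 land on 63.75 (a hypothetical check, reusing the nudge_min_max and fake_quant helpers sketched above):

import numpy as np

x = np.array([-0.1, 0.0, 63.75, 63.8], np.float32)
nmin, nmax, scale = nudge_min_max(-0.1, 63.65, 0, 255)  # -> 0.0, 63.75, 0.25
np.testing.assert_allclose(fake_quant(x, nmin, nmax, scale),
                           [0.0, 0.0, 63.75, 63.75], atol=1e-5)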

tests/st/ops/gpu/test_fake_quant_perchannel_grad.py (+373, -0)

@@ -0,0 +1,373 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
    def __init__(self, num_bits=8, narrow_range=False):
        super(Net, self).__init__()
        self.op = Q.FakeQuantPerChannelGrad(
            num_bits=num_bits, narrow_range=narrow_range)

    def construct(self, dout, x, minq, maxq):
        return self.op(dout, x, minq, maxq)
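The contract these tests encode is the straight-through estimator visible in the FakeQuantPerLayerGrad kernel body earlier in this diff: the incoming gradient passes through wherever the input sits inside the nudged range and is zeroed outside it. A NumPy sketch of that rule:

import numpy as np

def fake_quant_grad(dout, x, nudge_min, nudge_max):
    # Straight-through estimator: pass dout where nudge_min <= x <= nudge_max,
    # zero it elsewhere (mirrors the FakeQuantPerLayerGrad kernel body).
    return np.where((x < nudge_min) | (x > nudge_max), 0.0, dout).astype(np.float32)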


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad1():
    # WithVarsPerChannelDim1GradientNudgedDown_ZeroMinAndMax
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    min_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    max_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32)
    expect = dout

    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad2():
    # WithVarsPerChannelDim1GradientNudgedDown_RegularRange
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([-0.1, 0.0, 63.75, 63.8]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
    expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad3():
    # WithVarsPerChannelDim1GradientNudgedDown_NarrowRange
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([-0.1, 0.0, 63.5, 63.6]).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
    expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad4():
    # WithVarsPerChannelDim1GradientNudgedUp_RegularRange
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([-0.3, -0.25, 63.5, 63.6]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)
    expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad5():
    # WithVarsPerChannelDim1GradientNudgedUp_NarrowRange
    dout = np.random.uniform(-1, 1, size=[4]).astype('float32')
    x = np.array([-0.3, -0.25, 63.25, 63.3]).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)
    expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad6():
    # WithVarsPerChannelDim2GradientNudgedDown_RegularRange
    read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
    x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8]
                 ).reshape(3, 2).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65]).astype(np.float32)
    dout = read_dout.flatten()
    expect = np.array([0.0, dout[1], dout[2], dout[3],
                       dout[4], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(read_dout), Tensor(
        x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad7():
    # WithVarsPerChannelDim2GradientNudgedDown_NarrowRange
    read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
    x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]
                 ).reshape(3, 2).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.4, 63.4, 63.4]).astype(np.float32)
    dout = read_dout.flatten()
    expect = np.array([0.0, dout[1], dout[2], dout[3],
                       dout[4], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(read_dout), Tensor(
        x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad8():
    # WithVarsPerChannelDim2GradientNudgedUp_RegularRange
    read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
    x = np.array([-0.3, -0.25, -0.2, 0.0, 63.5, 63.6]
                 ).reshape(3, 2).astype(np.float32)
    min_val = np.array([-0.125, -0.125, -0.125]).astype(np.float32)
    max_val = np.array([63.625, 63.625, 63.625]).astype(np.float32)
    dout = read_dout.flatten()
    expect = np.array([0.0, dout[1], dout[2], dout[3],
                       dout[4], 0.0]).astype(np.float32)

    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(read_dout), Tensor(
        x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad9():
# WithVarsPerChannelDim2GradientNudgedUp_NarrowRange
read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
x = np.array([-0.3, -0.25, -0.2, 0.0, 63.25, 63.3]
).reshape(3, 2).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.375, 63.375, 63.375]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], dout[3],
dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad10():
# WithVarsPerChannelDim4GradientNudgedDown_RegularRange
read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
x = np.array([-0.1, 0.0, 63.75, 63.8, -0.1, 0.0,
63.75, 63.8, -0.1, 0.0, 63.75, 63.8,
-0.1, 0.0, 63.75, 63.8, -0.1, 0.0,
63.75, 63.8, -0.1, 0.0, 63.75, 63.8]).reshape(4, 3, 2, 1).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32)
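    # min/max carry one entry per channel (axis 0), but all four channels share
    # the same range, so the nudged bounds [0.0, 63.75] apply to every element.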
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], 0.0,
0.0, dout[5], dout[6], 0.0,
0.0, dout[9], dout[10], 0.0,
0.0, dout[13], dout[14], 0.0,
0.0, dout[17], dout[18], 0.0,
0.0, dout[21], dout[22], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad11():
# WithVarsPerChannelDim4GradientNudgedDown_NarrowRange
read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
x = np.array([-0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5,
63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32)
min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32)
max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], 0.0,
0.0, dout[5], dout[6], 0.0,
0.0, dout[9], dout[10], 0.0,
0.0, dout[13], dout[14], 0.0,
0.0, dout[17], dout[18], 0.0,
0.0, dout[21], dout[22], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad12():
# WithVarsPerChannelDim4GradientNudgedUp_RegularRange
read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
x = np.array([-0.3, -0.25, 63.5, 63.6, -0.3, -0.25,
63.5, 63.6, -0.3, -0.25, 63.5, 63.6,
-0.3, -0.25, 63.5, 63.6, -0.3, -0.25,
63.5, 63.6, -0.3, -0.25, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], 0.0,
0.0, dout[5], dout[6], 0.0,
0.0, dout[9], dout[10], 0.0,
0.0, dout[13], dout[14], 0.0,
0.0, dout[17], dout[18], 0.0,
0.0, dout[21], dout[22], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad13():
# WithVarsPerChannelDim4GradientNudgedUp_NarrowRange
read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
x = np.array([-0.3, -0.25, 63.25, 63.3, -0.3, -0.25,
63.25, 63.3, -0.3, -0.25, 63.25, 63.3,
-0.3, -0.25, 63.25, 63.3, -0.3, -0.25,
63.25, 63.3, -0.3, -0.25, 63.25, 63.3]).reshape(4, 3, 2, 1).astype(np.float32)
min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32)
max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32)
dout = read_dout.flatten()
expect = np.array([0.0, dout[1], dout[2], 0.0,
0.0, dout[5], dout[6], 0.0,
0.0, dout[9], dout[10], 0.0,
0.0, dout[13], dout[14], 0.0,
0.0, dout[17], dout[18], 0.0,
0.0, dout[21], dout[22], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
    output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("=" * 40)
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)

+ 386
- 0
tests/st/ops/gpu/test_fake_quant_perlayer.py View File

@@ -0,0 +1,386 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
def __init__(self,
num_bits=8,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
super(Net, self).__init__()
self.fake_quant = Q.FakeQuantPerLayer(num_bits=num_bits,
quant_delay=quant_delay,
symmetric=symmetric,
narrow_range=narrow_range,
training=training)

def construct(self, x, minq, maxq):
return self.fake_quant(x, minq, maxq)
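

# A minimal NumPy sketch (an assumption, not the kernel itself) of the TF-style
# nudging scheme the expected values below follow: [min, max] is snapped so that
# zero is exactly representable on the quantized grid, then inputs are clamped
# and rounded to that grid. Takes scalar min/max for brevity.
def _numpy_fake_quant(x, min_val, max_val, num_bits=8, narrow_range=False):
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    if max_val == min_val:
        return np.zeros_like(x)  # degenerate range quantizes everything to 0
    scale = (max_val - min_val) / (quant_max - quant_min)
    # floor(v + 0.5) rounds 0.5 up like the C round(); np.round would round
    # 0.5 to 0 (half-to-even) and nudge the wrong way in the "NudgedUp" cases
    nudged_zp = np.clip(np.floor(quant_min - min_val / scale + 0.5),
                        quant_min, quant_max)
    nudged_min = (quant_min - nudged_zp) * scale
    nudged_max = (quant_max - nudged_zp) * scale
    clamped = np.clip(x, nudged_min, nudged_max)
    return np.floor((clamped - nudged_min) / scale + 0.5) * scale + nudged_min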


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant1():
# (8, false, 0.0f, 0.0f, TensorShape({2, 3}),
# {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
# {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
x = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape(2, 3).astype(np.float32)
min_val = np.array([0]).reshape(1).astype(np.float32)
max_val = np.array([0]).reshape(1).astype(np.float32)
expect = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant2():
# 8, false, -10.0f, 53.75f, TensorShape({2, 3}),
# {-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f},
# {-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f});
x = np.array([-10.1, -10.0, -9.9, -9.75, 53.75, 53.8]).reshape(2, 3).astype(np.float32)
min_val = np.array([-10.0]).reshape(1).astype(np.float32)
max_val = np.array([53.75]).reshape(1).astype(np.float32)
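    # -10.0 and 53.75 are already integer multiples of scale = 63.75 / 255 = 0.25,
    # so no nudging is needed; inputs are simply clamped and rounded to the grid.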
expect = np.array([-10.0, -10.0, -10.0, -9.75, 53.75, 53.75]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant3():
# WithVarsNoNudging_NarrowRange
x = np.array([-10.1, -10.0, -9.90, -9.75, 53.5, 53.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-10.0]).reshape(1).astype(np.float32)
max_val = np.array([53.5]).reshape(1).astype(np.float32)
expect = np.array([-10.0, -10.0, -10.0, -9.75, 53.5, 53.5]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant4():
# WithVarsNudgedDown_RegularRange
x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.1]).reshape(1).astype(np.float32)
max_val = np.array([63.65]).reshape(1).astype(np.float32)
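    # scale = (63.65 + 0.1) / 255 = 0.25; the zero point 0.4 rounds down to 0,
    # nudging the range to [0.0, 63.75] (hence "NudgedDown").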
expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant5():
# WithVarsNudgedDown_NarrowRange
x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.1]).reshape(1).astype(np.float32)
max_val = np.array([63.4]).reshape(1).astype(np.float32)
expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.5, 63.5]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant6():
# WithVarsNudgedUp_RegularRange
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.625]).reshape(1).astype(np.float32)
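    # scale = (63.625 + 0.125) / 255 = 0.25; the zero point 0.5 rounds up to 1,
    # nudging the range to [-0.25, 63.5] (hence "NudgedUp").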
expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant7():
# WithVarsNudgedUp_NarrowRange
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.375]).reshape(1).astype(np.float32)
expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.25, 63.25]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant8():
# WithVarsNudgedZeroIs255_RegularRange
x = np.array([-63.80, -63.75, -63.70, -63.5, 0.0, 0.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-63.65]).reshape(1).astype(np.float32)
max_val = np.array([0.1]).reshape(1).astype(np.float32)
expect = np.array([-63.75, -63.75, -63.75, -63.5, 0.0, 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant9():
# WithVarsNudgedZeroIs255_NarrowRange
x = np.array([-63.6, -63.5, -63.4, -63.25, 0.0, 0.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-63.4]).reshape(1).astype(np.float32)
max_val = np.array([0.1]).reshape(1).astype(np.float32)
expect = np.array([-63.5, -63.5, -63.5, -63.25, 0.0, 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant10():
# WithVarsNoNudging_4Bits_RegularRange
x = np.array([-6.1, -6.0, -5.9, -5.5, 1.5, 1.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-6.0]).reshape(1).astype(np.float32)
max_val = np.array([1.5]).reshape(1).astype(np.float32)
expect = np.array([-6.0, -6.0, -6.0, -5.5, 1.5, 1.5]).astype(np.float32)

net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant11():
# WithVarsNoNudging_4Bits_NarrowRange
x = np.array([-6.1, -6.0, -5.9, -5.5, 1.0, 1.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-6.0]).reshape(1).astype(np.float32)
max_val = np.array([1.0]).reshape(1).astype(np.float32)
expect = np.array([-6.0, -6.0, -6.0, -5.5, 1.0, 1.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant12():
# WithVarsNudgedDown_4Bits_RegularRange
x = np.array([-0.1, 0.0, 0.1, 0.5, 7.5, 7.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.1]).reshape(1).astype(np.float32)
max_val = np.array([7.4]).reshape(1).astype(np.float32)
expect = np.array([-0.0, 0.0, 0.0, 0.5, 7.5, 7.5]).astype(np.float32)

net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant13():
# WithVarsNudgedDown_4Bits_NarrowRange
x = np.array([-0.1, 0.0, 0.1, 0.5, 7.0, 7.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.1]).reshape(1).astype(np.float32)
max_val = np.array([6.9]).reshape(1).astype(np.float32)
expect = np.array([-0.0, 0.0, 0.0, 0.5, 7.0, 7.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant14():
# WithVarsNudgedUp_4Bits_RegularRange
x = np.array([-0.6, -0.5, -0.24, 0.0, 7.0, 7.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([7.1]).reshape(1).astype(np.float32)
expect = np.array([-0.5, -0.5, -0.00, 0.0, 7.0, 7.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant15():
# WithVarsNudgedUp_4Bits_NarrowRange
x = np.array([-0.6, -0.5, -0.24, 0.0, 6.5, 6.6]).reshape(2, 3).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([6.6]).reshape(1).astype(np.float32)
expect = np.array([-0.5, -0.5, -0.00, 0.0, 6.5, 6.5]).astype(np.float32)

net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant16():
# WithVarsNudgedZero15_4Bits_RegularRange
x = np.array([-7.6, -7.5, -7.4, -7.2, 0.0, 0.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-7.3]).reshape(1).astype(np.float32)
max_val = np.array([0.2]).reshape(1).astype(np.float32)
expect = np.array([-7.5, -7.5, -7.5, -7.0, 0.0, 0.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant17():
# WithVarsNudgedZero15_4Bits_NarrowRange
x = np.array([-7.1, -7.0, -6.9, -6.5, 0.0, 0.1]).reshape(2, 3).astype(np.float32)
min_val = np.array([-6.8]).reshape(1).astype(np.float32)
max_val = np.array([0.2]).reshape(1).astype(np.float32)
expect = np.array([-7.0, -7.0, -7.0, -6.5, 0.0, 0.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)

+ 221
- 0
tests/st/ops/gpu/test_fake_quant_perlayer_grad.py View File

@@ -0,0 +1,221 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
def __init__(self, num_bits=8, narrow_range=False):
super(Net, self).__init__()
self.op = Q.FakeQuantPerLayerGrad(num_bits=num_bits, narrow_range=narrow_range)

def construct(self, dout, x, minq, maxq):
return self.op(dout, x, minq, maxq)
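

# A minimal NumPy sketch (an assumption mirroring TF's
# FakeQuantWithMinMaxVarsGradient semantics) of what these tests expect:
# dout passes through only where x lies inside the nudged [min, max] range.
def _numpy_fake_quant_grad(dout, x, min_val, max_val, num_bits=8, narrow_range=False):
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    if max_val == min_val:
        return dout  # degenerate range: the gradient passes through unchanged
    scale = (max_val - min_val) / (quant_max - quant_min)
    nudged_zp = np.clip(np.floor(quant_min - min_val / scale + 0.5),
                        quant_min, quant_max)
    nudged_min = (quant_min - nudged_zp) * scale
    nudged_max = (quant_max - nudged_zp) * scale
    return np.where((x >= nudged_min) & (x <= nudged_max), dout, 0.0).astype(np.float32)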


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad1():
# WithArgsGradient RegularRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.625]).reshape(1).astype(np.float32)
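    # Nudged range is [-0.25, 63.5] (scale = 0.25, zero point 0.5 rounds up to 1),
    # so the gradient is zeroed only for x = -0.26 and x = 63.6.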
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad2():
# WithArgsGradient NarrowRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.375]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad3():
# WithArgsGradient_4Bits_RegularRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1]).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([7.1]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad4():
# WithArgsGradient_4Bits_NarrowRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6]).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([6.6]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad5():
# FakeQuantWithMinMaxVarsGradient
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).astype(np.float32)
min_val = np.array([0.0]).reshape(1).astype(np.float32)
max_val = np.array([0.0]).reshape(1).astype(np.float32)
expect = dout

net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad6():
# WithVarsGradient_RegularRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6]).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.625]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad7():
# WithVarsGradient_NarrowRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3]).astype(np.float32)
min_val = np.array([-0.125]).reshape(1).astype(np.float32)
max_val = np.array([63.375]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=8, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad8():
# WithVarsGradient_4Bits_RegularRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1]).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([7.1]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=False)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad9():
# WithVarsGradient_4Bits_NarrowRange
dout = np.random.uniform(-1, 1, size=[6]).astype('float32')
x = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6]).astype(np.float32)
min_val = np.array([-0.4]).reshape(1).astype(np.float32)
max_val = np.array([6.6]).reshape(1).astype(np.float32)
expect = np.array([0.0, dout[1], dout[2], dout[3], dout[4], 0.0]).astype(np.float32)

net = Net(num_bits=4, narrow_range=True)
output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val))

error = np.ones(shape=expect.shape) * 1.0e-5
diff = output.asnumpy().flatten() - expect
print("output: ", output)
print("expect: ", expect)
assert np.all(np.abs(diff) < error)
