Merge pull request !1032 from SanjayChan/bug_fix
@@ -20,8 +20,8 @@
 #include "device/gpu/cuda_common.h"
 #include "fake_quant_impl.cuh"
-__global__ void FakeQuantize(const float* input, float* output, const int size, const float* nudge_min,
-                             const float* nudge_max, const float* scale, bool symmetric) {
+__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
+                             const float *nudge_max, const float *scale, bool symmetric) {
   float input_x = 0.f;
   int nudge_input = 0;

@@ -43,8 +43,8 @@ __global__ void FakeQuantize(const float* input, float* output, const int size,
   return;
 }
-__global__ void FakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
-                                 const float* nudge_min, const float* nudge_max) {
+__global__ void FakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
+                                 const float *nudge_min, const float *nudge_max) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
     if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) {
       output[i] = 0;
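Between these hunks it is worth restating what FakeQuantizeGrad computes: a straight-through estimator, where the incoming gradient passes through unchanged for inputs inside the nudged range and is zeroed for clipped inputs. The in-range branch lies outside the hunk above, so the sketch below is an illustrative reconstruction of the kernel body, not part of the patch.

```cuda
// Straight-through-estimator sketch consistent with the FakeQuantizeGrad hunk above.
// The pass-through branch is assumed from context; only the clipping test is visible in the diff.
__global__ void FakeQuantizeGradSketch(const float *input, const float *gradient, float *output, const int size,
                                       const float *nudge_min, const float *nudge_max) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    // Zero the gradient where the input was clipped by the fake-quant range, pass it through otherwise.
    output[i] = (input[i] < nudge_min[0] || input[i] > nudge_max[0]) ? 0.f : gradient[i];
  }
}
```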
@@ -55,15 +55,18 @@ __global__ void FakeQuantizeGrad(const float* input, const float* gradient, floa
   return;
 }
-__global__ void NudgeMinMax(const float* input_min, const float* input_max, const float quant_min,
-                            const float quant_max, float* nudge_min, float* nudge_max, float* scale) {
+__global__ void NudgeMinMax(const float *input_min, const float *input_max, const float quant_min,
+                            const float quant_max, float *nudge_min, float *nudge_max, float *scale) {
   float zp_from_min = 0.f;
-  if ((quant_max - quant_min) == 0 || (*input_max - *input_min) == 0) {
-    *scale = 0.f;
+  scale[0] = 0.f;
+  nudge_max[0] = 0.f;
+  nudge_min[0] = 0.f;
+  if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) {
+    scale[0] = 0.f;
     zp_from_min = 0.f;
   } else {
-    *scale = (*input_max - *input_min) / (quant_max - quant_min);
-    zp_from_min = quant_min - *input_min / *scale;
+    scale[0] = (input_max[0] - input_min[0]) / (quant_max - quant_min);
+    zp_from_min = quant_min - input_min[0] / scale[0];
   }
   float nudge_zp = 0.f;
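The added zero-initialization makes the degenerate case (an empty quantization or input range) produce an all-zero scale and nudged range instead of leaving the outputs untouched. For reference, the whole nudge computation follows the usual affine fake-quant recipe; a minimal host-side restatement (names are illustrative, and the clamping and rounding of the zero point continue in the next hunk):

```cpp
#include <algorithm>
#include <cmath>

// Host-side sketch of NudgeMinMax for a single (min, max) pair; illustrative only.
struct NudgedRange { float scale, nudge_min, nudge_max; };

NudgedRange NudgeReference(float input_min, float input_max, float quant_min, float quant_max) {
  NudgedRange r{0.f, 0.f, 0.f};
  if ((quant_max - quant_min) == 0 || (input_max - input_min) == 0) {
    return r;  // degenerate range: scale and nudged bounds stay zero
  }
  r.scale = (input_max - input_min) / (quant_max - quant_min);
  float zp_from_min = quant_min - input_min / r.scale;
  // Clamp the zero point into [quant_min, quant_max] and round it so that the real
  // value 0.0 lands exactly on an integer grid point.
  float nudge_zp = std::round(std::min(std::max(zp_from_min, quant_min), quant_max));
  r.nudge_min = (quant_min - nudge_zp) * r.scale;
  r.nudge_max = (quant_max - nudge_zp) * r.scale;
  return r;
}
```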
@@ -75,59 +78,59 @@ __global__ void NudgeMinMax(const float* input_min, const float* input_max, cons
     nudge_zp = round(zp_from_min);
   }
-  *nudge_min = (quant_min - nudge_zp) * (*scale);
-  *nudge_max = (quant_max - nudge_zp) * (*scale);
+  nudge_min[0] = (quant_min - nudge_zp) * (scale[0]);
+  nudge_max[0] = (quant_max - nudge_zp) * (scale[0]);
   return;
 }
-__global__ void UpdateInputMinMaxWithEMA(float* input_min, float* input_max, const float min, const float max,
+__global__ void UpdateInputMinMaxWithEMA(float *input_min, float *input_max, const float min, const float max,
                                          const float decay) {
-  *input_min = decay * (min) + (1 - decay) * (*input_min);
-  *input_min = *input_min > 0 ? 0 : *input_min;
-  *input_max = decay * (max) + (1 - decay) * (*input_max);
-  *input_max = *input_max < 0 ? 0 : *input_max;
+  input_min[0] = decay * (min) + (1 - decay) * (input_min[0]);
+  input_min[0] = input_min[0] > 0 ? 0 : input_min[0];
+  input_max[0] = decay * (max) + (1 - decay) * (input_max[0]);
+  input_max[0] = input_max[0] < 0 ? 0 : input_max[0];
   return;
 }
-__global__ void UpdateInputMinMax(float* input_min, float* input_max, const float min, const float max) {
-  *input_min = min;
-  *input_max = max;
+__global__ void UpdateInputMinMax(float *input_min, float *input_max, const float min, const float max) {
+  input_min[0] = min > 0 ? 0 : min;
+  input_max[0] = max < 0 ? 0 : max;
 }
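Both min/max update kernels now fold zero into the recorded range: the running minimum is clamped to at most 0 and the running maximum to at least 0. Keeping 0 inside [min, max] is what allows the nudging step above to place an exact grid point at the real value 0. A one-line host restatement of that clamping, purely for illustration:

```cpp
// Illustrative helper mirroring the ternaries above: the recorded range always contains zero.
inline void FoldZeroIntoRange(float observed_min, float observed_max, float *range_min, float *range_max) {
  *range_min = observed_min > 0.f ? 0.f : observed_min;  // min <= 0
  *range_max = observed_max < 0.f ? 0.f : observed_max;  // max >= 0
}
```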
-void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max,
-                     const float* scale, bool symmetric, cudaStream_t cuda_stream) {
+void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
+                     const float *scale, bool symmetric, cudaStream_t cuda_stream) {
   FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale,
                                                                   symmetric);
   return;
 }
-void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
-                         const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream) {
+void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
+                         const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
   FakeQuantizeGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
                                                                       nudge_max);
   return;
 }
-void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
-              float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream) {
-  NudgeMinMax<<<1, 1>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
+void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
+              float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream) {
+  NudgeMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
   return;
 }
-void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema,
+void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
                cudaStream_t cuda_stream) {
   float minel = 0.f;
   float maxel = 0.f;
+  auto policy = thrust::cuda::par.on(cuda_stream);
   thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
-  tuple = thrust::minmax_element(thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
+  tuple = thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
   minel = tuple.first[0];
   maxel = tuple.second[0];
   if (ema) {
-    UpdateInputMinMaxWithEMA<<<1, 1>>>(input_min, input_max, minel, maxel, ema_decay);
+    UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay);
   } else {
-    UpdateInputMinMax<<<1, 1>>>(input_min, input_max, minel, maxel);
+    UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel);
   }
   return;
 }
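CalMinMax now runs the Thrust reduction and the single-thread update kernels on the caller's stream rather than the default stream, so the min/max pass stays ordered with the rest of the kernel's work. The same pattern in isolation (function name and buffers are illustrative):

```cpp
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/extrema.h>
#include <thrust/system/cuda/execution_policy.h>

// Sketch of a stream-ordered Thrust min/max reduction over a device buffer of length n.
void MinMaxOnStream(float *d_data, int n, cudaStream_t stream, float *h_min, float *h_max) {
  auto policy = thrust::cuda::par.on(stream);   // execute the algorithm on this stream
  auto first = thrust::device_pointer_cast(d_data);
  auto result = thrust::minmax_element(policy, first, first + n);
  *h_min = result.first[0];                     // device_ptr reads imply a device-to-host copy
  *h_max = result.second[0];
}
```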
@@ -17,16 +17,16 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
-void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max,
-                     const float* scale, bool symmetric, cudaStream_t cuda_stream);
+void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
+                     const float *scale, bool symmetric, cudaStream_t cuda_stream);
-void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
-                         const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream);
+void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size,
+                         const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);
-void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
-              float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream);
+void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
+              float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream);
-void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema,
+void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema,
                cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
@@ -34,8 +34,8 @@
  * @param channel_num
  * @return
  */
-__global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input_max, const float quant_min,
-                                      const float quant_max, float* nudge_min, float* nudge_max, float* scale,
+__global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input_max, const float quant_min,
+                                      const float quant_max, float *nudge_min, float *nudge_max, float *scale,
                                       int channel_num) {
   float zp_from_min = 0.f;
   float nudge_zp = 0.f;

@@ -62,8 +62,8 @@ __global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input
   }
 }
-void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
-                        float* nudge_min, float* nudge_max, float* scale, const int channel_num,
+void CalNudgePerChannel(const float *input_min, const float *input_max, const float quant_min, const float quant_max,
+                        float *nudge_min, float *nudge_max, float *scale, const int channel_num,
                         cudaStream_t cuda_stream) {
   NudgeMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
     input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num);

@@ -80,8 +80,8 @@ void CalNudgePerChannel(const float* input_min, const float* input_max, const fl
  * @param scale - array
  * @return
  */
-__global__ void FakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size,
-                                       const float* nudge_min, const float* nudge_max, const float* scale,
+__global__ void FakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
+                                       const float *nudge_min, const float *nudge_max, const float *scale,
                                        bool symmetric) {
   float input_x = 0.f;
   int nudge_input = 0;

@@ -106,8 +106,8 @@ __global__ void FakeQuantizePerChannel(const float* input, float* output, const
   }
 }
-void CalFakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size,
-                               const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
+void CalFakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size,
+                               const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric,
                                cudaStream_t cuda_stream) {
   FakeQuantizePerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(
     input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);

@@ -121,10 +121,10 @@ void CalFakeQuantizePerChannel(const float* input, float* output, const int tota
  * @param max
  * @return
  */
-__global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max, float* input, int channels,
+__global__ void UpdateInputMinMaxPerChannel(float *input_min, float *input_max, float *input, int channels,
                                             int per_channel_nums, bool ema, float ema_decay) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
-    thrust::pair<float*, float*> sum =
+    thrust::pair<float *, float *> sum =
       thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
     if (ema) {
       input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];

@@ -133,25 +133,27 @@ __global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max,
       input_min[i] = sum.first[0];
       input_max[i] = sum.second[0];
     }
+    input_min[i] = input_min[i] > 0 ? 0 : input_min[i];
+    input_max[i] = input_max[i] < 0 ? 0 : input_max[i];
   }
 }
-__global__ void UpdateInputMinMaxPerChannelWithEMA(float* input_min, float* input_max, float min, float max,
+__global__ void UpdateInputMinMaxPerChannelWithEMA(float *input_min, float *input_max, float min, float max,
                                                    const float decay) {
   *input_min = decay * (min) + (1 - decay) * (*input_min);
   *input_max = decay * (max) + (1 - decay) * (*input_max);
 }
-void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_size, const int channel_size,
+void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, const int total_size, const int channel_size,
                          const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
   int per_channel_num = total_size / channel_size;
   UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(
     input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay);
 }
-__global__ void FakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output,
-                                           const int total_size, const int channel_size, const float* nudge_min,
-                                           const float* nudge_max) {
+__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output,
+                                           const int total_size, const int channel_size, const float *nudge_min,
+                                           const float *nudge_max) {
   int channel_idx = 0;
   int per_channel_num = total_size / channel_size;

@@ -165,10 +167,9 @@ __global__ void FakeQuantizePerChannelGrad(const float* input, const float* grad
   }
 }
-void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
-                                   const int channel_num, const float* nudge_min, const float* nudge_max,
+void CalFakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
+                                   const int channel_num, const float *nudge_min, const float *nudge_max,
                                    cudaStream_t cuda_stream) {
   FakeQuantizePerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
     input, gradient, output, total_num, channel_num, nudge_min, nudge_max);
 }
@@ -114,8 +114,7 @@ class BatchNormFold2GpuKernel : public GpuKernel {
     output_size_list_.push_back(input_size);
-    size_t workspace_size = 0;
-    workspace_size_list_.push_back(workspace_size);
+    workspace_size_list_.push_back(sizeof(int32_t));
   }

  private:

@@ -70,9 +70,12 @@ class BatchNormFold2GradGpuKernel : public GpuKernel {
     int32_t current_step_host[1];
     size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost),
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Failed to copy gpu memory.");
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(d_x, dout, x_size, cudaMemcpyDeviceToDevice), "Failed to copy gpu memory.");
+    CHECK_CUDA_RET_WITH_ERROR(
+      cudaMemcpyAsync(d_x, dout, x_size, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "Failed to copy gpu memory.");
     BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_,
                              reinterpret_cast<cudaStream_t>(stream_ptr));
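Here and in the kernels below, blocking cudaMemcpy calls are replaced by cudaMemcpyAsync on the kernel's own stream, so the copies are ordered with the launched work instead of synchronizing the whole device. For a device-to-host read of the step counter, the host still has to wait for the copy before inspecting the value; a hedged sketch of that pattern (the helper name and the explicit synchronize are illustrative, not a claim about the surrounding kernel code):

```cpp
#include <cuda_runtime.h>
#include <cstdint>

// Stream-ordered read of a device-resident step counter; illustrative only.
int32_t ReadGlobalStep(const int32_t *global_step, cudaStream_t stream) {
  int32_t step_host = 0;
  cudaMemcpyAsync(&step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost, stream);
  // The host may only read step_host after the copy has completed on the stream.
  cudaStreamSynchronize(stream);
  return step_host;
}
```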
@@ -55,12 +55,13 @@ class BatchNormFoldGpuKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
     (void)workspace;
-    auto x = reinterpret_cast<T *>(inputs[0]->addr);
-    auto mean = reinterpret_cast<T *>(inputs[1]->addr);
-    auto variance = reinterpret_cast<T *>(inputs[2]->addr);
-    int *current_step = reinterpret_cast<int *>(inputs[3]->addr);
+    auto x = GetDeviceAddress<T>(inputs, 0);
+    auto mean = GetDeviceAddress<T>(inputs, 1);
+    auto variance = GetDeviceAddress<T>(inputs, 2);
+    int *current_step = GetDeviceAddress<int>(inputs, 3);
     int current_step_host[1];
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost),
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Copy gpu memoy failed.");
     if (x == nullptr) {
       MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null.";
@@ -78,15 +79,17 @@ class BatchNormFoldGpuKernel : public GpuKernel {
       MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null.";
       return false;
     }
-    auto batch_mean = reinterpret_cast<T *>(outputs[0]->addr);
-    auto batch_std = reinterpret_cast<T *>(outputs[1]->addr);
-    auto running_mean = reinterpret_cast<T *>(outputs[2]->addr);
-    auto running_std = reinterpret_cast<T *>(outputs[3]->addr);
-    auto y = reinterpret_cast<T *>(workspace[0]->addr);
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice),
+    auto batch_mean = GetDeviceAddress<T>(outputs, 0);
+    auto batch_std = GetDeviceAddress<T>(outputs, 1);
+    auto running_mean = GetDeviceAddress<T>(outputs, 2);
+    auto running_std = GetDeviceAddress<T>(outputs, 3);
+    auto y = GetDeviceAddress<T>(workspace, 0);
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Failed to copy gpu memory.");
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_std, variance, output_size_, cudaMemcpyDeviceToDevice),
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_std, variance, output_size_, cudaMemcpyDeviceToDevice,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Failed to copy gpu memory.");
     CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast<cudaStream_t>(stream_ptr));
     if (!is_training_ || current_step_host[0] >= freeze_bn_) {

@@ -57,7 +57,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
     T *batch_std = GetDeviceAddress<T>(inputs, 4);
     int *current_step = GetDeviceAddress<int>(inputs, 5);
     int current_step_host[1];
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost),
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Copy gpu memoy failed.");
     if (d_batch_mean == nullptr) {
       MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null.";

@@ -83,7 +84,7 @@ class BatchNormFoldGradGpuKernel : public GpuKernel {
       MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null.";
       return false;
     }
-    T *dx = reinterpret_cast<T *>(outputs[0]->addr);
+    T *dx = GetDeviceAddress<T>(outputs, 0);
     if (!is_training_ || current_step_host[0] >= freeze_bn_) {
       ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
@@ -60,7 +60,7 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
   num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
   ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
-  ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
+  ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
   training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
   if (num_bits_ <= 2 || num_bits_ >= 16) {
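With this change the "ema_decay" attribute is used as-is rather than flipped to 1 - ema_decay. The EMA kernels shown earlier then weight the new observation by decay directly; restated as a one-liner for clarity:

```cpp
// EMA update as applied by UpdateInputMinMaxWithEMA above: `decay` weights the new observation.
inline float EmaUpdate(float running, float observed, float decay) {
  return decay * observed + (1.f - decay) * running;
}
```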
@@ -115,7 +115,6 @@ void FakeQuantGpuKernel::InitSizeLists() {
 bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                 const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
-  (void)workspace;
   float *output = GetDeviceAddress<float>(outputs, 0);
   float *input = GetDeviceAddress<float>(inputs, 0);
   float *input_min = GetDeviceAddress<float>(inputs, 1);

@@ -151,7 +150,8 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
     CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
                     reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Copy gpu memory failed");
   }
   global_step_++;
@@ -93,7 +93,6 @@ void FakeQuantGradGpuKernel::InitSizeLists() {
 bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                     const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
-  (void)workspace;
   float *output = GetDeviceAddress<float>(outputs, 0);
   float *gradient = GetDeviceAddress<float>(inputs, 0);
   float *input = GetDeviceAddress<float>(inputs, 1);

@@ -133,8 +132,9 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
     CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
     CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
   } else {
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice),
-                              "Copy gpu memory failed.");
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
+                              "Copy gpu memory failed");
   }
   global_step_++;
   return true;
@@ -107,11 +107,13 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
 }
 void FakeQuantPerChannelGpuKernel::InitSizeLists() {
-  input_size_list_.push_back(input_size_);        // input
-  input_size_list_.push_back(min_size_);          // min
-  input_size_list_.push_back(max_size_);          // max
-  output_size_list_.push_back(output_size_);
-  workspace_size_list_.push_back(workspace_size_);
+  input_size_list_.push_back(input_size_);                       // input in tensor
+  input_size_list_.push_back(min_size_);                         // min one scalar
+  input_size_list_.push_back(max_size_);                         // max on scalar
+  output_size_list_.push_back(output_size_);                     // output in tensor
+  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // scale in channel
+  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // min in channel
+  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // max in channel
 }
 void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,

@@ -128,8 +130,9 @@ void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, floa
     CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
                               d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
   } else {
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
-                              "Copy gpu memory failed.");
+    CHECK_CUDA_RET_WITH_ERROR(
+      cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+      "Copy gpu memory failed.");
   }
   global_step_++;
 }

@@ -152,6 +155,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
   float *input = GetDeviceAddress<float>(inputs, 0);
   float *input_min = GetDeviceAddress<float>(inputs, 1);
   float *input_max = GetDeviceAddress<float>(inputs, 2);
+  float *d_scale = GetDeviceAddress<float>(workspace, 0);
+  float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
   if (input == nullptr) {
     MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
@@ -160,27 +166,12 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
     MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null.";
   }
-  // Allocate space for device copies
-  float *d_scale = nullptr;
-  float *d_nudge_min = nullptr;
-  float *d_nudge_max = nullptr;
-  CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), sizeof(float) * channel_out_),
-                            "Malloc gpu memory failed");
-  CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), sizeof(float) * channel_out_),
-                            "Malloc gpu memory failed");
-  CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), sizeof(float) * channel_out_),
-                            "Malloc gpu memory failed");
   if (training_) {
     CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
   } else {
     CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
   }
-  // Cleanup
-  CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
-  CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
-  CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
   return true;
 }
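Rather than calling cudaMalloc and cudaFree on every Launch, the per-channel scale and nudge buffers are now declared once as workspace in InitSizeLists and fetched with GetDeviceAddress(workspace, i), letting the kernel framework own the allocations. The pairing, condensed into a sketch (function name is illustrative):

```cpp
#include <cstddef>
#include <vector>

// Workspace declaration matching the InitSizeLists changes above: three float[channel_out] buffers.
void DeclarePerChannelWorkspace(std::vector<size_t> *workspace_size_list, int channel_out) {
  workspace_size_list->push_back(sizeof(float) * channel_out);  // scale
  workspace_size_list->push_back(sizeof(float) * channel_out);  // nudge_min
  workspace_size_list->push_back(sizeof(float) * channel_out);  // nudge_max
}
// In Launch the framework-provided buffers are then retrieved as:
//   float *d_scale     = GetDeviceAddress<float>(workspace, 0);
//   float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
//   float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
```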
@@ -97,7 +97,9 @@ void FakeQuantPerChannelGradGpuKernel::InitSizeLists() {
   input_size_list_.push_back(min_size_);  // min
   input_size_list_.push_back(max_size_);  // max
   output_size_list_.push_back(output_size_);
-  workspace_size_list_.push_back(workspace_size_);
+  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // scale in channel
+  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // min in channel
+  workspace_size_list_.push_back(sizeof(float) * channel_out_);  // max in channel
 }
 bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,

@@ -109,6 +111,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
   float *input = GetDeviceAddress<float>(inputs, 1);
   float *input_min = GetDeviceAddress<float>(inputs, 2);
   float *input_max = GetDeviceAddress<float>(inputs, 3);
+  float *d_scale = GetDeviceAddress<float>(workspace, 0);
+  float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
+  float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
   if (gradient == nullptr) {
     MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null";

@@ -125,28 +130,13 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
   int total_size = input_size_ / sizeof(float);
   if (global_step_ >= quant_delay_) {
-    float *d_scale = nullptr;
-    float *d_nudge_min = nullptr;
-    float *d_nudge_max = nullptr;
-    // Allocate space for device copies
-    CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), channel_out_ * sizeof(float)),
-                              "Malloc gpu memory failed");
-    CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), channel_out_ * sizeof(float)),
-                              "Malloc gpu memory failed");
-    CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), channel_out_ * sizeof(float)),
-                              "Malloc gpu memory failed");
     CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
                        reinterpret_cast<cudaStream_t>(stream_ptr));
     CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max,
                                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    // Cleanup
-    CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
-    CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
-    CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
   } else {
-    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice),
+    CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
+                                              reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Copy gpu memory failed.");
   }
   global_step_++;