diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu index 7b09256e1d..f25727f2c3 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu @@ -20,8 +20,8 @@ #include "device/gpu/cuda_common.h" #include "fake_quant_impl.cuh" -__global__ void FakeQuantize(const float* input, float* output, const int size, const float* nudge_min, - const float* nudge_max, const float* scale, bool symmetric) { +__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min, + const float *nudge_max, const float *scale, bool symmetric) { float input_x = 0.f; int nudge_input = 0; @@ -43,8 +43,8 @@ __global__ void FakeQuantize(const float* input, float* output, const int size, return; } -__global__ void FakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size, - const float* nudge_min, const float* nudge_max) { +__global__ void FakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) { output[i] = 0; @@ -55,15 +55,18 @@ __global__ void FakeQuantizeGrad(const float* input, const float* gradient, floa return; } -__global__ void NudgeMinMax(const float* input_min, const float* input_max, const float quant_min, - const float quant_max, float* nudge_min, float* nudge_max, float* scale) { +__global__ void NudgeMinMax(const float *input_min, const float *input_max, const float quant_min, + const float quant_max, float *nudge_min, float *nudge_max, float *scale) { float zp_from_min = 0.f; - if ((quant_max - quant_min) == 0 || (*input_max - *input_min) == 0) { - *scale = 0.f; + scale[0] = 0.f; + nudge_max[0] = 0.f; + nudge_min[0] = 0.f; + if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) { + scale[0] = 0.f; zp_from_min = 0.f; } else { - *scale = (*input_max - *input_min) / (quant_max - quant_min); - zp_from_min = quant_min - *input_min / *scale; + scale[0] = (input_max[0] - input_min[0]) / (quant_max - quant_min); + zp_from_min = quant_min - input_min[0] / scale[0]; } float nudge_zp = 0.f; @@ -75,59 +78,59 @@ __global__ void NudgeMinMax(const float* input_min, const float* input_max, cons nudge_zp = round(zp_from_min); } - *nudge_min = (quant_min - nudge_zp) * (*scale); - *nudge_max = (quant_max - nudge_zp) * (*scale); + nudge_min[0] = (quant_min - nudge_zp) * (scale[0]); + nudge_max[0] = (quant_max - nudge_zp) * (scale[0]); return; } -__global__ void UpdateInputMinMaxWithEMA(float* input_min, float* input_max, const float min, const float max, +__global__ void UpdateInputMinMaxWithEMA(float *input_min, float *input_max, const float min, const float max, const float decay) { - *input_min = decay * (min) + (1 - decay) * (*input_min); - *input_min = *input_min > 0 ? 0 : *input_min; - *input_max = decay * (max) + (1 - decay) * (*input_max); - *input_max = *input_max < 0 ? 0 : *input_max; + input_min[0] = decay * (min) + (1 - decay) * (input_min[0]); + input_min[0] = input_min[0] > 0 ? 0 : input_min[0]; + input_max[0] = decay * (max) + (1 - decay) * (input_max[0]); + input_max[0] = input_max[0] < 0 ? 0 : input_max[0]; return; } -__global__ void UpdateInputMinMax(float* input_min, float* input_max, const float min, const float max) { - *input_min = min; - *input_max = max; +__global__ void UpdateInputMinMax(float *input_min, float *input_max, const float min, const float max) { + input_min[0] = min > 0 ? 0 : min; + input_max[0] = max < 0 ? 0 : max; } -void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max, - const float* scale, bool symmetric, cudaStream_t cuda_stream) { +void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max, + const float *scale, bool symmetric, cudaStream_t cuda_stream) { FakeQuantize<<>>(input, output, size, nudge_min, nudge_max, scale, symmetric); return; } -void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size, - const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream) { +void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) { FakeQuantizeGrad<<>>(input, gradient, output, size, nudge_min, nudge_max); return; } -void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max, - float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream) { - NudgeMinMax<<<1, 1>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale); +void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream) { + NudgeMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale); return; } -void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema, +void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream) { float minel = 0.f; float maxel = 0.f; + auto policy = thrust::cuda::par.on(cuda_stream); thrust::pair, thrust::device_ptr> tuple; - tuple = thrust::minmax_element(thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size); + tuple = thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size); minel = tuple.first[0]; maxel = tuple.second[0]; if (ema) { - UpdateInputMinMaxWithEMA<<<1, 1>>>(input_min, input_max, minel, maxel, ema_decay); + UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay); } else { - UpdateInputMinMax<<<1, 1>>>(input_min, input_max, minel, maxel); + UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel); } return; } - diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cuh index c88c1f79e2..27c39dead1 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cuh @@ -17,16 +17,16 @@ #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ -void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max, - const float* scale, bool symmetric, cudaStream_t cuda_stream); +void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max, + const float *scale, bool symmetric, cudaStream_t cuda_stream); -void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size, - const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream); +void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream); -void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max, - float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream); +void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream); -void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema, +void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cu index 09153bf28f..b9aac9bdc3 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cu @@ -34,8 +34,8 @@ * @param channel_num * @return */ -__global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input_max, const float quant_min, - const float quant_max, float* nudge_min, float* nudge_max, float* scale, +__global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input_max, const float quant_min, + const float quant_max, float *nudge_min, float *nudge_max, float *scale, int channel_num) { float zp_from_min = 0.f; float nudge_zp = 0.f; @@ -62,8 +62,8 @@ __global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input } } -void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max, - float* nudge_min, float* nudge_max, float* scale, const int channel_num, +void CalNudgePerChannel(const float *input_min, const float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const int channel_num, cudaStream_t cuda_stream) { NudgeMinMaxPerChannel<<>>( input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num); @@ -80,8 +80,8 @@ void CalNudgePerChannel(const float* input_min, const float* input_max, const fl * @param scale - array * @return */ -__global__ void FakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size, - const float* nudge_min, const float* nudge_max, const float* scale, +__global__ void FakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size, + const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric) { float input_x = 0.f; int nudge_input = 0; @@ -106,8 +106,8 @@ __global__ void FakeQuantizePerChannel(const float* input, float* output, const } } -void CalFakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size, - const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric, +void CalFakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size, + const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric, cudaStream_t cuda_stream) { FakeQuantizePerChannel<<>>( input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric); @@ -121,10 +121,10 @@ void CalFakeQuantizePerChannel(const float* input, float* output, const int tota * @param max * @return */ -__global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max, float* input, int channels, +__global__ void UpdateInputMinMaxPerChannel(float *input_min, float *input_max, float *input, int channels, int per_channel_nums, bool ema, float ema_decay) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) { - thrust::pair sum = + thrust::pair sum = thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1)); if (ema) { input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i]; @@ -133,25 +133,27 @@ __global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max, input_min[i] = sum.first[0]; input_max[i] = sum.second[0]; } + input_min[i] = input_min[i] > 0 ? 0 : input_min[i]; + input_max[i] = input_max[i] < 0 ? 0 : input_max[i]; } } -__global__ void UpdateInputMinMaxPerChannelWithEMA(float* input_min, float* input_max, float min, float max, +__global__ void UpdateInputMinMaxPerChannelWithEMA(float *input_min, float *input_max, float min, float max, const float decay) { *input_min = decay * (min) + (1 - decay) * (*input_min); *input_max = decay * (max) + (1 - decay) * (*input_max); } -void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_size, const int channel_size, +void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, const int total_size, const int channel_size, const float ema_decay, const bool ema, cudaStream_t cuda_stream) { int per_channel_num = total_size / channel_size; UpdateInputMinMaxPerChannel<<>>( input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay); } -__global__ void FakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, - const int total_size, const int channel_size, const float* nudge_min, - const float* nudge_max) { +__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, + const int total_size, const int channel_size, const float *nudge_min, + const float *nudge_max) { int channel_idx = 0; int per_channel_num = total_size / channel_size; @@ -165,10 +167,9 @@ __global__ void FakeQuantizePerChannelGrad(const float* input, const float* grad } } -void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num, - const int channel_num, const float* nudge_min, const float* nudge_max, +void CalFakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, const int total_num, + const int channel_num, const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) { FakeQuantizePerChannelGrad<<>>( input, gradient, output, total_num, channel_num, nudge_min, nudge_max); } - diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h index beeeb12a9a..c1804a5b93 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h @@ -114,8 +114,7 @@ class BatchNormFold2GpuKernel : public GpuKernel { output_size_list_.push_back(input_size); - size_t workspace_size = 0; - workspace_size_list_.push_back(workspace_size); + workspace_size_list_.push_back(sizeof(int32_t)); } private: diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h index 099960e7fa..38adda718c 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h @@ -70,9 +70,12 @@ class BatchNormFold2GradGpuKernel : public GpuKernel { int32_t current_step_host[1]; size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost), + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), "Failed to copy gpu memory."); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(d_x, dout, x_size, cudaMemcpyDeviceToDevice), "Failed to copy gpu memory."); + CHECK_CUDA_RET_WITH_ERROR( + cudaMemcpyAsync(d_x, dout, x_size, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), + "Failed to copy gpu memory."); BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_, reinterpret_cast(stream_ptr)); diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h index 3e8c1ca52b..a5a8a10dc0 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_gpu_kernel.h @@ -55,12 +55,13 @@ class BatchNormFoldGpuKernel : public GpuKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, uintptr_t stream_ptr) override { (void)workspace; - auto x = reinterpret_cast(inputs[0]->addr); - auto mean = reinterpret_cast(inputs[1]->addr); - auto variance = reinterpret_cast(inputs[2]->addr); - int *current_step = reinterpret_cast(inputs[3]->addr); + auto x = GetDeviceAddress(inputs, 0); + auto mean = GetDeviceAddress(inputs, 1); + auto variance = GetDeviceAddress(inputs, 2); + int *current_step = GetDeviceAddress(inputs, 3); int current_step_host[1]; - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost), + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), "Copy gpu memoy failed."); if (x == nullptr) { MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null."; @@ -78,15 +79,17 @@ class BatchNormFoldGpuKernel : public GpuKernel { MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null."; return false; } - auto batch_mean = reinterpret_cast(outputs[0]->addr); - auto batch_std = reinterpret_cast(outputs[1]->addr); - auto running_mean = reinterpret_cast(outputs[2]->addr); - auto running_std = reinterpret_cast(outputs[3]->addr); - auto y = reinterpret_cast(workspace[0]->addr); - - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice), + auto batch_mean = GetDeviceAddress(outputs, 0); + auto batch_std = GetDeviceAddress(outputs, 1); + auto running_mean = GetDeviceAddress(outputs, 2); + auto running_std = GetDeviceAddress(outputs, 3); + auto y = GetDeviceAddress(workspace, 0); + + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), "Failed to copy gpu memory."); - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_std, variance, output_size_, cudaMemcpyDeviceToDevice), + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(running_std, variance, output_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), "Failed to copy gpu memory."); CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast(stream_ptr)); if (!is_training_ || current_step_host[0] >= freeze_bn_) { diff --git a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h index ec845fbb9e..cc420781da 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h @@ -57,7 +57,8 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { T *batch_std = GetDeviceAddress(inputs, 4); int *current_step = GetDeviceAddress(inputs, 5); int current_step_host[1]; - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost), + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost, + reinterpret_cast(stream_ptr)), "Copy gpu memoy failed."); if (d_batch_mean == nullptr) { MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null."; @@ -83,7 +84,7 @@ class BatchNormFoldGradGpuKernel : public GpuKernel { MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null."; return false; } - T *dx = reinterpret_cast(outputs[0]->addr); + T *dx = GetDeviceAddress(outputs, 0); if (!is_training_ || current_step_host[0] >= freeze_bn_) { ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast(stream_ptr)); diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc index f4e2c74aac..ee1cb0d012 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc @@ -60,7 +60,7 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) { num_bits_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits")); ema_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema")); - ema_decay_ = 1.0 - GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); + ema_decay_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay")); training_ = GetValue(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training")); if (num_bits_ <= 2 || num_bits_ >= 16) { @@ -115,7 +115,6 @@ void FakeQuantGpuKernel::InitSizeLists() { bool FakeQuantGpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, uintptr_t stream_ptr) { - (void)workspace; float *output = GetDeviceAddress(outputs, 0); float *input = GetDeviceAddress(inputs, 0); float *input_min = GetDeviceAddress(inputs, 1); @@ -151,7 +150,8 @@ bool FakeQuantGpuKernel::Launch(const std::vector &inputs, const std CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_, reinterpret_cast(stream_ptr)); } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice), + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), "Copy gpu memory failed"); } global_step_++; diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc index 4746e8e8e0..239e55b5b0 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc @@ -93,7 +93,6 @@ void FakeQuantGradGpuKernel::InitSizeLists() { bool FakeQuantGradGpuKernel::Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, uintptr_t stream_ptr) { - (void)workspace; float *output = GetDeviceAddress(outputs, 0); float *gradient = GetDeviceAddress(inputs, 0); float *input = GetDeviceAddress(inputs, 1); @@ -133,8 +132,9 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector &inputs, const CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice), - "Copy gpu memory failed."); + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), + "Copy gpu memory failed"); } global_step_++; return true; diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc index 1da9f457a1..c452bb5dd1 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc @@ -107,11 +107,13 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) { } void FakeQuantPerChannelGpuKernel::InitSizeLists() { - input_size_list_.push_back(input_size_); // input - input_size_list_.push_back(min_size_); // min - input_size_list_.push_back(max_size_); // max - output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(workspace_size_); + input_size_list_.push_back(input_size_); // input in tensor + input_size_list_.push_back(min_size_); // min one scalar + input_size_list_.push_back(max_size_); // max on scalar + output_size_list_.push_back(output_size_); // output in tensor + workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel + workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel + workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel } void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min, @@ -128,8 +130,9 @@ void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, floa CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric_, reinterpret_cast(stream_ptr)); } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice), - "Copy gpu memory failed."); + CHECK_CUDA_RET_WITH_ERROR( + cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), + "Copy gpu memory failed."); } global_step_++; } @@ -152,6 +155,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, float *input = GetDeviceAddress(inputs, 0); float *input_min = GetDeviceAddress(inputs, 1); float *input_max = GetDeviceAddress(inputs, 2); + float *d_scale = GetDeviceAddress(workspace, 0); + float *d_nudge_min = GetDeviceAddress(workspace, 1); + float *d_nudge_max = GetDeviceAddress(workspace, 2); if (input == nullptr) { MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null."; @@ -160,27 +166,12 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null."; } - // Allocate space for device copies - float *d_scale = nullptr; - float *d_nudge_min = nullptr; - float *d_nudge_max = nullptr; - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_scale), sizeof(float) * channel_out_), - "Malloc gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_nudge_min), sizeof(float) * channel_out_), - "Malloc gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_nudge_max), sizeof(float) * channel_out_), - "Malloc gpu memory failed"); - if (training_) { CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr); } else { CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr); } - // Cleanup - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); return true; } diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc index 3184132121..f995f81190 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc @@ -97,7 +97,9 @@ void FakeQuantPerChannelGradGpuKernel::InitSizeLists() { input_size_list_.push_back(min_size_); // min input_size_list_.push_back(max_size_); // max output_size_list_.push_back(output_size_); - workspace_size_list_.push_back(workspace_size_); + workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel + workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel + workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel } bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inputs, @@ -109,6 +111,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inp float *input = GetDeviceAddress(inputs, 1); float *input_min = GetDeviceAddress(inputs, 2); float *input_max = GetDeviceAddress(inputs, 3); + float *d_scale = GetDeviceAddress(workspace, 0); + float *d_nudge_min = GetDeviceAddress(workspace, 1); + float *d_nudge_max = GetDeviceAddress(workspace, 2); if (gradient == nullptr) { MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null"; @@ -125,28 +130,13 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inp int total_size = input_size_ / sizeof(float); if (global_step_ >= quant_delay_) { - float *d_scale = nullptr; - float *d_nudge_min = nullptr; - float *d_nudge_max = nullptr; - // Allocate space for device copies - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_scale), channel_out_ * sizeof(float)), - "Malloc gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_nudge_min), channel_out_ * sizeof(float)), - "Malloc gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast(&d_nudge_max), channel_out_ * sizeof(float)), - "Malloc gpu memory failed"); - CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_, reinterpret_cast(stream_ptr)); CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max, reinterpret_cast(stream_ptr)); - - // Cleanup - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed"); - CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed"); } else { - CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice), + CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, + reinterpret_cast(stream_ptr)), "Copy gpu memory failed."); } global_step_++;