From d667d6ee92339b3f93f0a8b280ad6b4e68fdac71 Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Mon, 17 Aug 2020 15:45:43 +0800 Subject: [PATCH] bugfix:SigmoidCrossEntropyWithLogitsGrad need multiply dout --- ...igmoid_cross_entropy_with_logits_grad_impl.cu | 16 ++++++++-------- ...gmoid_cross_entropy_with_logits_grad_impl.cuh | 4 ++-- ...d_cross_entropy_with_logits_grad_gpu_kernel.h | 4 +++- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu index f0c64bfb01..e83dbff060 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu @@ -18,24 +18,24 @@ template __global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels, - T *outputs) { + const T *dout_addr, T *outputs) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { if (logits[i] >= 0) { - outputs[i] = 1. / (1. + exp(-logits[i])) - labels[i]; + outputs[i] = (1. / (1. + exp(-logits[i])) - labels[i]) * dout_addr[i]; } else { const T exp_val = exp(logits[i]); - outputs[i] = exp_val / (1. + exp_val) - labels[i]; + outputs[i] = (exp_val / (1. + exp_val) - labels[i]) * dout_addr[i]; } } } template -void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream) { +void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, const T *dout_addr, + T *outputs, cudaStream_t cuda_stream) { SigmoidCrossEntropyWithLogitsGradKernel<<>>(size, logits, labels, - outputs); + dout_addr, outputs); } template void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const float *logits, - const float *labels, float *outputs, - cudaStream_t cuda_stream); + const float *labels, const float *dout_addr, + float *outputs, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh index 6b444d6c02..d9a1c6f1df 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh @@ -19,7 +19,7 @@ #include "runtime/device/gpu/cuda_common.h" template -void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, - cudaStream_t cuda_stream); +void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, const T *dout_addr, + T *outputs, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h index 873f9c5be1..abbc4456f4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h @@ -38,9 +38,10 @@ class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { const std::vector &outputs, void *stream_ptr) override { T *logits_addr = GetDeviceAddress(inputs, 0); S *labels_addr = GetDeviceAddress(inputs, 1); + T *dout_addr = GetDeviceAddress(inputs, 2); T *outputs_addr = GetDeviceAddress(outputs, 0); - SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, + SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, dout_addr, outputs_addr, reinterpret_cast(stream_ptr)); return true; } @@ -78,6 +79,7 @@ class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { void InitSizeLists() override { input_size_list_.push_back(logits_size_); input_size_list_.push_back(labels_size_); + input_size_list_.push_back(logits_size_); output_size_list_.push_back(outputs_size_); }