From 8fa4422dac895e5c68ea4d3f7dd224e62f79df19 Mon Sep 17 00:00:00 2001
From: tom__chen
Date: Fri, 21 Aug 2020 12:54:40 -0800
Subject: [PATCH] fix non-sparse cross entropy gpu kernel

fix white space
---
 .../gpu/cuda_impl/cross_entropy_impl.cu       | 22 +++++++++++++------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu
index 987cd1adde..21875eed70 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cross_entropy_impl.cu
@@ -52,12 +52,18 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
 }
 
 template <typename T, typename S>
-__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) {
-  losses[threadIdx.x] = 0;
-  T epsilon = 1e-6;
-  for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) {
-    losses[threadIdx.x] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i];
-    dlogits[i] = logits[i] - labels[i];
+__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
+                                   T epsilon, T *losses, T *dlogits) {
+  for (size_t index = blockIdx.x * blockDim.x + threadIdx.x;
+       index < batch_size;
+       index += blockDim.x * gridDim.x) {
+    losses[index] = 0;
+    const int start = index * class_num;
+    const int end = (index + 1) * class_num;
+    for (int i = start; i < end; ++i) {
+      losses[index] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i];
+      dlogits[i] = logits[i] - labels[i];
+    }
   }
 }
 
@@ -79,7 +85,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
 template <typename T, typename S>
 void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
                   T *losses, T *dlogits, cudaStream_t cuda_stream) {
-  CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits);
+  T epsilon = 1e-6;
+  CrossEntropyKernel<<<GET_BLOCKS(batch_size), GET_THREADS, 0, cuda_stream>>>(logits, labels, batch_size, class_num,
+                                                                              epsilon, losses, dlogits);
 }
 
 template void CrossEntropyWithSparse<float, int>(const float *logits, const int *labels, const size_t batch_size,