| @@ -1,47 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <stdint.h> | |||
| #include "cross_entropy_cuda_impl.cuh" | |||
| #include "include/cuda_runtime.h" | |||
| __global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits, | |||
| const float *labels, const int batch_size, const int num_classes, | |||
| float *loss, float *dx) { | |||
| extern __shared__ float loss_shared[]; | |||
| const float mean_scale = 1.0f / static_cast<float>(batch_size); | |||
| loss_shared[threadIdx.x] = 0; | |||
| for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) { | |||
| loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i]; | |||
| dx[i] = (softmax_logits[i] - labels[i]) * mean_scale; | |||
| } | |||
| __syncthreads(); | |||
| if (threadIdx.x == 0) { | |||
| *loss = 0; | |||
| for (int i = 0; i < batch_size; i++) { | |||
| *loss += loss_shared[i]; | |||
| } | |||
| *loss *= mean_scale; | |||
| } | |||
| } | |||
| void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, | |||
| const int batch_size, const int num_classes, float *loss, float *dx, | |||
| cudaStream_t cuda_stream) { | |||
| CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>( | |||
| softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx); | |||
| } | |||
| @@ -1,26 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ | |||
| #include "device/gpu/cuda_common.h" | |||
| void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, | |||
| const int batch_size, const int num_classes, float *loss, float *dx, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ | |||
| @@ -52,38 +52,12 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label | |||
| } | |||
| template <typename T, typename S> | |||
| __global__ void CrossEntropyWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size, | |||
| const size_t class_num, T *losses) { | |||
| T epsilon = 1e-6; | |||
| for (size_t i = 0; i < batch_size; ++i) { | |||
| T logit = 0.0; | |||
| for (size_t j = 0; j < class_num; j++) { | |||
| if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) { | |||
| logit = logits[i * class_num + j]; | |||
| break; | |||
| } | |||
| } | |||
| if (logit <= 0) { | |||
| logit += epsilon; | |||
| } | |||
| losses[i] = -logf(logit); | |||
| __global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) { | |||
| losses[threadIdx.x] = 0; | |||
| for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) { | |||
| losses[threadIdx.x] -= logf(logits[i]) * labels[i]; | |||
| dlogits[i] = logits[i] - labels[i]; | |||
| } | |||
| return; | |||
| } | |||
| template <typename T, typename S> | |||
| __global__ void CrossEntropyGradWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size, | |||
| const size_t class_num, T *grad) { | |||
| for (size_t i = 0; i < batch_size; i++) { | |||
| for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) { | |||
| if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) { | |||
| grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size; | |||
| } else { | |||
| grad[i * class_num + j] = logits[i * class_num + j] / batch_size; | |||
| } | |||
| } | |||
| } | |||
| return; | |||
| } | |||
| template <typename T, typename S> | |||
| @@ -102,18 +76,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b | |||
| } | |||
| template <typename T, typename S> | |||
| void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, | |||
| T *losses, cudaStream_t cuda_stream) { | |||
| CrossEntropyWithoutSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, losses); | |||
| return; | |||
| } | |||
| template <typename T, typename S> | |||
| void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, | |||
| T *grad, cudaStream_t cuda_stream) { | |||
| CrossEntropyGradWithoutSparseKernel<<<GET_BLOCKS(class_num), GET_THREADS, 0, cuda_stream>>>( | |||
| logits, labels, batch_size, class_num, grad); | |||
| return; | |||
| void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, | |||
| T *dlogits, cudaStream_t cuda_stream) { | |||
| CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits); | |||
| } | |||
| template void CrossEntropyWithSparse<float, int>(const float *logits, const int *labels, const size_t batch_size, | |||
| @@ -126,8 +91,6 @@ template void CrossEntropyGradWithSparse<float, int>(const float *logits, const | |||
| template void CrossEntropyGradWithSparse<float, int64_t>(const float *logits, const int64_t *labels, | |||
| const size_t batch_size, const size_t class_num, float *grad, | |||
| cudaStream_t cuda_stream); | |||
| template void CrossEntropyWithoutSparse<float, float>(const float *logits, const float *labels, const size_t batch_size, | |||
| const size_t class_num, float *losses, cudaStream_t cuda_stream); | |||
| template void CrossEntropyGradWithoutSparse<float, float>(const float *logits, const float *labels, | |||
| const size_t batch_size, const size_t class_num, float *grad, | |||
| cudaStream_t cuda_stream); | |||
| template void CrossEntropy<float, float>(const float *logits, const float *labels, const size_t batch_size, | |||
| const size_t class_num, float *losses, float *dlogits, | |||
| cudaStream_t cuda_stream); | |||
| @@ -28,11 +28,6 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b | |||
| T *grad, cudaStream_t cuda_stream); | |||
| template <typename T, typename S> | |||
| void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, | |||
| T *losses, cudaStream_t cuda_stream); | |||
| template <typename T, typename S> | |||
| void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, | |||
| T *grad, cudaStream_t cuda_stream); | |||
| void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, | |||
| T *dlogits, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ | |||
| @@ -58,8 +58,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { | |||
| } | |||
| T *logits_addr = GetDeviceAddress<T>(inputs, 0); | |||
| S *labels_addr = GetDeviceAddress<S>(inputs, 1); | |||
| T *output1_addr = GetDeviceAddress<T>(outputs, 0); | |||
| T *output2_addr = GetDeviceAddress<T>(outputs, 1); | |||
| T *loss_addr = GetDeviceAddress<T>(outputs, 0); | |||
| T *dlogits_addr = GetDeviceAddress<T>(outputs, 1); | |||
| T *softmax_output_logits = GetDeviceAddress<T>(workspace, 0); | |||
| const float alpha = 1; | |||
| @@ -69,10 +69,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { | |||
| softmax_output_descriptor_, softmax_output_logits), | |||
| "cudnnSoftmaxForward failed."); | |||
| CrossEntropyWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output1_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CrossEntropyGradWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output2_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||