| @@ -18,24 +18,24 @@ | |||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| __global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels, | __global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const T *logits, const S *labels, | ||||
| T *outputs) { | |||||
| const T *dout_addr, T *outputs) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | ||||
| if (logits[i] >= 0) { | if (logits[i] >= 0) { | ||||
| outputs[i] = 1. / (1. + exp(-logits[i])) - labels[i]; | |||||
| outputs[i] = (1. / (1. + exp(-logits[i])) - labels[i]) * dout_addr[i]; | |||||
| } else { | } else { | ||||
| const T exp_val = exp(logits[i]); | const T exp_val = exp(logits[i]); | ||||
| outputs[i] = exp_val / (1. + exp_val) - labels[i]; | |||||
| outputs[i] = (exp_val / (1. + exp_val) - labels[i]) * dout_addr[i]; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, | |||||
| cudaStream_t cuda_stream) { | |||||
| void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, const T *dout_addr, | |||||
| T *outputs, cudaStream_t cuda_stream) { | |||||
| SigmoidCrossEntropyWithLogitsGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, logits, labels, | SigmoidCrossEntropyWithLogitsGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, logits, labels, | ||||
| outputs); | |||||
| dout_addr, outputs); | |||||
| } | } | ||||
| template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits, | template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits, | ||||
| const float *labels, float *outputs, | |||||
| cudaStream_t cuda_stream); | |||||
| const float *labels, const float *dout_addr, | |||||
| float *outputs, cudaStream_t cuda_stream); | |||||
| @@ -19,7 +19,7 @@ | |||||
| #include "runtime/device/gpu/cuda_common.h" | #include "runtime/device/gpu/cuda_common.h" | ||||
| template <typename T, typename S> | template <typename T, typename S> | ||||
| void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, T *outputs, | |||||
| cudaStream_t cuda_stream); | |||||
| void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const S *labels, const T *dout_addr, | |||||
| T *outputs, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ | #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_ | ||||
| @@ -38,9 +38,10 @@ class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | ||||
| T *logits_addr = GetDeviceAddress<T>(inputs, 0); | T *logits_addr = GetDeviceAddress<T>(inputs, 0); | ||||
| S *labels_addr = GetDeviceAddress<S>(inputs, 1); | S *labels_addr = GetDeviceAddress<S>(inputs, 1); | ||||
| T *dout_addr = GetDeviceAddress<T>(inputs, 2); | |||||
| T *outputs_addr = GetDeviceAddress<T>(outputs, 0); | T *outputs_addr = GetDeviceAddress<T>(outputs, 0); | ||||
| SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, | |||||
| SigmoidCrossEntropyWithLogitsGrad(inputs[0]->size / sizeof(T), logits_addr, labels_addr, dout_addr, outputs_addr, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -78,6 +79,7 @@ class SigmoidCrossEntropyWithLogitsGradGpuKernel : public GpuKernel { | |||||
| void InitSizeLists() override { | void InitSizeLists() override { | ||||
| input_size_list_.push_back(logits_size_); | input_size_list_.push_back(logits_size_); | ||||
| input_size_list_.push_back(labels_size_); | input_size_list_.push_back(labels_size_); | ||||
| input_size_list_.push_back(logits_size_); | |||||
| output_size_list_.push_back(outputs_size_); | output_size_list_.push_back(outputs_size_); | ||||
| } | } | ||||