Browse Source

Add protection in cross entropy kernel.

tags/v0.5.0-beta
ZPaC 5 years ago
parent
commit
4814e78162
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu

+ 3
- 2
mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu View File

@@ -27,7 +27,7 @@ __global__ void CrossEntropyWithSparseKernel(const T *logits, const S *labels, c
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
T logit = logits[i * class_num + labels[i]]; T logit = logits[i * class_num + labels[i]];
if (logit <= 0) { if (logit <= 0) {
logit += epsilon;
logit = epsilon;
} }
total_loss += -logf(logit); total_loss += -logf(logit);
} }
@@ -54,8 +54,9 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
template <typename T, typename S> template <typename T, typename S>
__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) { __global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) {
losses[threadIdx.x] = 0; losses[threadIdx.x] = 0;
T epsilon = 1e-6;
for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) { for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) {
losses[threadIdx.x] -= logf(logits[i]) * labels[i];
losses[threadIdx.x] -= logf((logits[i] <= 0 ? epsilon : logits[i])) * labels[i];
dlogits[i] = logits[i] - labels[i]; dlogits[i] = logits[i] - labels[i];
} }
} }


Loading…
Cancel
Save