Browse Source

added GetReductionInt to common_utils.h and replaced duplicated code in all loss with reduction gpu op kernels (nll loss, kl div loss, and binary cross entropy)

tags/v1.6.0
markuskunej 4 years ago
parent
commit
abdba421e5
8 changed files with 27 additions and 41 deletions
  1. +11
    -0
      mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
  2. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/common_utils.h
  3. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h
  4. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h
  5. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
  6. +2
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
  7. +3
    -11
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h
  8. +4
    -10
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h

+ 11
- 0
mindspore/ccsrc/backend/kernel_compiler/common_utils.cc View File

@@ -571,6 +571,17 @@ int Sign(float x) {
return 0;
}

int GetReductionInt(const std::string &reduction) {
if (reduction == "none") {
return 0;
} else if (reduction == "sum") {
return 2;
} else {
// reduction = 'mean'
return 1;
}
}

std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);



+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/common_utils.h View File

@@ -89,6 +89,7 @@ std::string GetProcessor(const AnfNodePtr &anf_node);
Processor GetProcessor(const string &processor);
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
int Sign(float x);
int GetReductionInt(const std::string &reduction);
std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index);
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list);


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h View File

@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
@@ -67,11 +68,7 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
if (reduction == "none") {
reduction_ = 0;
} else if (reduction == "sum") {
reduction_ = 2;
}
reduction_ = GetReductionInt(reduction);
workspace_size_ = sizeof(T);
if (reduction_ != 0) {
workspace_size_ *= input_size_;


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_grad_kernel.h View File

@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
@@ -69,11 +70,7 @@ class BinaryCrossEntropyGradGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
if (reduction == "none") {
reduction_ = 0;
} else if (reduction == "sum") {
reduction_ = 2;
}
reduction_ = GetReductionInt(reduction);
InitSizeLists();
return true;
}


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h View File

@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
@@ -60,11 +61,7 @@ class KLDivLossGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
if (reduction == "none") {
reduction_ = 0;
} else if (reduction == "sum") {
reduction_ = 2;
}
reduction_ = GetReductionInt(reduction);
workspace_size_ = sizeof(T);
if (reduction_ == 0) {
workspace_size_ *= input_size_;


+ 2
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h View File

@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
namespace kernel {
@@ -61,11 +62,7 @@ class KLDivLossGradGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");
if (reduction == "none") {
reduction_ = 0;
} else if (reduction == "sum") {
reduction_ = 2;
}
reduction_ = GetReductionInt(reduction);
InitSizeLists();
return true;
}


+ 3
- 11
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_gpu_kernel.h View File

@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace kernel {
@@ -60,19 +61,10 @@ class NLLLossGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");

// if reduction is not 'none', tmp_nll is (N,) size
if (reduction == "none") {
reduction_ = 0;
} else if (reduction == "sum") {
reduction_ = 2;
tmp_loss_size_ = sizeof(T) * n_;
} else {
// reduction = 'mean'
reduction_ = 1;
reduction_ = GetReductionInt(reduction);
if ((reduction_ == 2) || (reduction_ == 1)) {
tmp_loss_size_ = sizeof(T) * n_;
}

tmp_target_weight_size_ = n_ * sizeof(S);

InitSizeLists();


+ 4
- 10
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/nll_loss_grad_gpu_kernel.h View File

@@ -22,6 +22,7 @@
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace kernel {
@@ -59,16 +60,9 @@ class NLLLossGradGpuKernel : public GpuKernel {
input_size_ *= input_shape[i];
}
string reduction = GetAttr<string>(kernel_node, "reduction");

// if reduction is not 'none', tmp_nll is (N,) size
if (reduction == "none") {
reduction_ = 0;
num_dloss_ = n_; // dloss is a vector
} else if (reduction == "sum") {
reduction_ = 2;
} else {
// reduction = 'mean'
reduction_ = 1;
reduction_ = GetReductionInt(reduction);
if (reduction_ == 0) {
num_dloss_ = n_;
}

InitSizeLists();


Loading…
Cancel
Save