diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
index edf1929261..544bc7e80e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
@@ -62,13 +62,9 @@ __global__ void KLDivLossKernel(const int input_size, const int reduction, const
 }
 
 template <typename T>
-void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss,
+void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss,
               cudaStream_t stream) {
   LossInitKernel<<<1, 1, 0, stream>>>(loss);
-  T *tmp_loss;
-  if (reduction != 0) {
-    cudaMalloc(reinterpret_cast<void **>(&tmp_loss), input_size * sizeof(T));
-  }
   KLDivLossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x, input_y, loss,
                                                                      tmp_loss);
   if (reduction != 0) {
@@ -83,7 +79,6 @@ void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, co
     }
     Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size);
   }
-  cudaFree(tmp_loss);
 }
 
 template <typename T>
@@ -119,19 +114,17 @@ void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x
 template <typename T>
 __global__ void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const T *input_x,
                                              const T *input_y, const T *weight, T *loss, T *tmp_loss) {
-  T epsilon = 1e-6;
+  T epsilon = 1e-12;
   if (reduction == 0) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T antilogarithm = max(input_x[i], epsilon);
-      T antilogarithm2 = min(1 - input_x[i], 1 - epsilon);
-      T value = -weight[i] * (input_y[i] * logf(antilogarithm) + (1 - input_y[i]) * logf(antilogarithm2));
+      T value =
+        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
       loss[i] = value;
     }
   } else {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T antilogarithm = max(input_x[i], epsilon);
-      T antilogarithm2 = min(1 - input_x[i], 1 - epsilon);
-      T value = -weight[i] * (input_y[i] * logf(antilogarithm) + (1 - input_y[i]) * logf(antilogarithm2));
+      T value =
+        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
       tmp_loss[i] = value;
     }
   }
@@ -139,12 +132,8 @@ __global__ void BinaryCrossEntropyLossKernel(const int input_size, const int red
 
 template <typename T>
 void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
-                            const T *weight, T *loss, cudaStream_t stream) {
+                            const T *weight, T *loss, T *tmp_loss, cudaStream_t stream) {
   LossInitKernel<<<1, 1, 0, stream>>>(loss);
-  T *tmp_loss;
-  if (reduction != 0) {
-    cudaMalloc(reinterpret_cast<void **>(&tmp_loss), input_size * sizeof(T));
-  }
   BinaryCrossEntropyLossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x,
                                                                                    input_y, weight, loss, tmp_loss);
   if (reduction != 0) {
@@ -159,13 +148,12 @@ void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T
     }
     Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size);
   }
-  cudaFree(tmp_loss);
 }
 
 template <typename T>
 __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int reduction, const T *input_x,
                                                  const T *input_y, const T *weight, const T *dloss, T *dx) {
-  T epsilon = 1e-6;
+  T epsilon = 1e-12;
   if (reduction == 0) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T denominator = max(input_x[i] * (1 - input_x[i]), epsilon);
@@ -193,13 +181,14 @@ void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, con
 }
 
 template void KLDivLoss<float>(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
-                               float *loss, cudaStream_t stream);
+                               float *loss, float *tmp_loss, cudaStream_t stream);
 
 template void KLDivLossGrad<float>(const int &input_size, const int &reduction, const float *input_x,
                                    const float *input_y, const float *dloss, float *dx, float *dy, cudaStream_t stream);
 
 template void BinaryCrossEntropyLoss<float>(const int &input_size, const int &reduction, const float *input_x,
-                                            const float *input_y, const float *weight, float *loss, cudaStream_t stream);
+                                            const float *input_y, const float *weight, float *loss, float *tmp_loss,
+                                            cudaStream_t stream);
 
 template void BinaryCrossEntropyLossGrad<float>(const int &input_size, const int &reduction, const float *input_x,
                                                 const float *input_y, const float *weight, const float *dloss, float *dx,
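Two behavioral changes in the file above are worth calling out. First, tmp_loss is now a caller-provided workspace: the old wrappers called cudaMalloc on every launch when reduction != 0, and the unconditional cudaFree(tmp_loss) at the end freed an uninitialized pointer whenever reduction == 0. Second, the BCE kernels no longer clamp input_x with max/min before taking the log; they add epsilon inside it, and epsilon shrinks from 1e-6 to 1e-12. A standalone check of what the tightened epsilon does at the saturation point (hypothetical host-side driver, not MindSpore code):

    #include <cmath>
    #include <cstdio>
    // With a fully wrong prediction (x = 0, y = 1) the old clamp capped the
    // per-element loss at -logf(1e-6f) ~= 13.82, while the new form yields
    // -logf(1e-12f) ~= 27.63, so confident mistakes are penalized harder
    // instead of saturating early.
    int main() {
      float x = 0.0f, y = 1.0f, w = 1.0f, eps = 1e-12f;
      float old_loss = -w * (y * std::logf(std::fmaxf(x, 1e-6f)) +
                             (1 - y) * std::logf(std::fminf(1 - x, 1 - 1e-6f)));
      float new_loss = -w * (y * std::logf(x + eps) + (1 - y) * std::logf(1 - x + eps));
      std::printf("old=%f new=%f\n", old_loss, new_loss);  // ~13.82 vs ~27.63
      return 0;
    }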
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
index a01ca830f7..06f5350918 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh
@@ -18,12 +18,12 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
 template <typename T>
 void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
-                            const T *weight, T *loss, cudaStream_t stream);
+                            const T *weight, T *loss, T *tmp_loss, cudaStream_t stream);
 template <typename T>
 void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
                                 const T *weight, const T *dloss, T *dx, cudaStream_t stream);
 template <typename T>
-void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss,
+void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss,
               cudaStream_t stream);
 template <typename T>
 void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, const T *dloss,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h
index 8ccbc22d68..00b9f2fbbf 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/binary_cross_entropy_gpu_kernel.h
@@ -30,19 +30,17 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
  public:
   BinaryCrossEntropyGpuKernel() : input_size_(1), reduction_(1) {}
   ~BinaryCrossEntropyGpuKernel() override = default;
-
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input_x = GetDeviceAddress<T>(inputs, 0);
     T *input_y = GetDeviceAddress<T>(inputs, 1);
     T *weight = GetDeviceAddress<T>(inputs, 2);
     T *loss = GetDeviceAddress<T>(outputs, 0);
-
-    BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss,
+    T *tmp_loss = GetDeviceAddress<T>(workspace, 0);
+    BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, tmp_loss,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -52,13 +50,16 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
     for (size_t i = 0; i < input_shape.size(); i++) {
       input_size_ *= input_shape[i];
     }
-
     string reduction = GetAttr<string>(kernel_node, "reduction");
     if (reduction == "none") {
       reduction_ = 0;
     } else if (reduction == "sum") {
       reduction_ = 2;
     }
+    workspace_size_ = sizeof(T);
+    if (reduction_ != 0) {
+      workspace_size_ *= input_size_;
+    }
     InitSizeLists();
     return true;
   }
@@ -73,12 +74,13 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
     } else {
       output_size_list_.push_back(sizeof(T));
     }
+    workspace_size_list_.push_back(workspace_size_);
   }
 
  private:
   size_t input_size_;
   int reduction_;
-
+  size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
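The header above wires the scratch buffer through the GpuKernel interface: Init() computes workspace_size_, InitSizeLists() publishes it via workspace_size_list_, and Launch() reads the buffer back with GetDeviceAddress<T>(workspace, 0) on the assumption that the framework allocates one device buffer per listed size. For code that calls the raw wrapper directly, a minimal sketch of the new contract (the helper name LaunchBce, the include path, and the explicit synchronize are assumptions, not MindSpore code):

    #include <cuda_runtime.h>
    #include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
    // The caller now owns the scratch buffer the wrapper used to cudaMalloc.
    void LaunchBce(const float *x, const float *y, const float *w, float *loss,
                   int input_size, int reduction, cudaStream_t stream) {
      float *tmp_loss = nullptr;
      size_t bytes = (reduction != 0) ? input_size * sizeof(float) : sizeof(float);
      cudaMalloc(reinterpret_cast<void **>(&tmp_loss), bytes);
      BinaryCrossEntropyLoss(input_size, reduction, x, y, w, loss, tmp_loss, stream);
      cudaStreamSynchronize(stream);  // ensure the kernels finish before freeing scratch
      cudaFree(tmp_loss);
    }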
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
index 43aced9494..74eb340a3e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
@@ -35,13 +35,13 @@ class KLDivLossGpuKernel : public GpuKernel {
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
 
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input_x = GetDeviceAddress<T>(inputs, 0);
     T *input_y = GetDeviceAddress<T>(inputs, 1);
     T *loss = GetDeviceAddress<T>(outputs, 0);
-
-    KLDivLoss(input_size_, reduction_, input_x, input_y, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
+    T *tmp_loss = GetDeviceAddress<T>(workspace, 0);
+    KLDivLoss(input_size_, reduction_, input_x, input_y, loss, tmp_loss, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
@@ -50,13 +50,16 @@ class KLDivLossGpuKernel : public GpuKernel {
     for (size_t i = 0; i < input_shape.size(); i++) {
       input_size_ *= input_shape[i];
    }
-
     string reduction = GetAttr<string>(kernel_node, "reduction");
     if (reduction == "none") {
       reduction_ = 0;
     } else if (reduction == "sum") {
       reduction_ = 2;
     }
+    workspace_size_ = sizeof(T);
+    if (reduction_ != 0) {
+      workspace_size_ *= input_size_;
+    }
     InitSizeLists();
     return true;
   }
@@ -70,12 +73,13 @@ class KLDivLossGpuKernel : public GpuKernel {
     } else {
       output_size_list_.push_back(sizeof(T));
     }
+    workspace_size_list_.push_back(workspace_size_);
   }
 
  private:
   size_t input_size_;
   int reduction_;
-
+  size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
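Both kernels share the same reduction encoding and the same sizing rule for the new workspace: tmp_loss is written at every index only when a reduction runs, so it needs one slot per input element on the mean/sum paths and a single placeholder slot otherwise. Restated as a sketch (the enum is invented for readability; the kernels themselves use raw ints):

    #include <cstddef>
    enum Reduction { kNone = 0, kMean = 1, kSum = 2 };  // kernel ctors default to kMean
    // Mirrors the workspace_size_ computation in the two Init() methods above.
    size_t WorkspaceBytes(int reduction, size_t input_size, size_t elem_bytes) {
      return reduction != kNone ? input_size * elem_bytes : elem_bytes;
    }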
diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py
index feb1778da9..0c41e717f1 100644
--- a/mindspore/nn/probability/distribution/categorical.py
+++ b/mindspore/nn/probability/distribution/categorical.py
@@ -72,6 +72,7 @@ class Categorical(Distribution):
                  dtype=mstype.int32,
                  name="Categorical"):
         param = dict(locals())
+        param['param_dict'] = {'probs': probs, 'logits': logits}
         valid_dtype = mstype.int_type
         check_type(dtype, valid_dtype, "Categorical")
         super(Categorical, self).__init__(seed, dtype, name, param)