@@ -62,13 +62,9 @@ __global__ void KLDivLossKernel(const int input_size, const int reduction, const
 }
 
 template <typename T>
-void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss,
+void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss,
                cudaStream_t stream) {
   LossInitKernel<<<1, 1, 0, stream>>>(loss);
-  T *tmp_loss;
-  if (reduction != 0) {
-    cudaMalloc(reinterpret_cast<void **>(&tmp_loss), input_size * sizeof(T));
-  }
   KLDivLossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x, input_y, loss,
                                                                       tmp_loss);
   if (reduction != 0) {
@@ -83,7 +79,6 @@ void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, co
     }
     Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size);
   }
-  cudaFree(tmp_loss);
 }
 
 template <typename T>
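Note on the two hunks above: KLDivLoss no longer allocates and frees its reduction scratch buffer with cudaMalloc/cudaFree on every launch; the buffer is now passed in as tmp_loss and, in the kernel classes further down, comes from the framework-allocated workspace. A minimal sketch of the resulting calling pattern, assuming the declarations from loss_with_reduction_impl.cuh are in scope (the helper LaunchKLDivLoss and its raw cudaMalloc are illustrative only, not part of this patch):

#include <cuda_runtime.h>
// Assumes the KLDivLoss<T> declaration from loss_with_reduction_impl.cuh is visible.

void LaunchKLDivLoss(int input_size, int reduction, const float *input_x, const float *input_y, float *loss,
                     cudaStream_t stream) {
  float *tmp_loss = nullptr;
  // The caller owns the scratch buffer (in the real kernels it is the framework workspace),
  // so KLDivLoss itself stays free of per-launch cudaMalloc/cudaFree.
  cudaMalloc(reinterpret_cast<void **>(&tmp_loss), input_size * sizeof(float));
  KLDivLoss(input_size, reduction, input_x, input_y, loss, tmp_loss, stream);
  cudaStreamSynchronize(stream);  // make sure the kernels finished before releasing the buffer
  cudaFree(tmp_loss);
}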
@@ -119,19 +114,17 @@ void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x
 template <typename T>
 __global__ void BinaryCrossEntropyLossKernel(const int input_size, const int reduction, const T *input_x,
                                              const T *input_y, const T *weight, T *loss, T *tmp_loss) {
-  T epsilon = 1e-6;
+  T epsilon = 1e-12;
   if (reduction == 0) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T antilogarithm = max(input_x[i], epsilon);
-      T antilogarithm2 = min(1 - input_x[i], 1 - epsilon);
-      T value = -weight[i] * (input_y[i] * logf(antilogarithm) + (1 - input_y[i]) * logf(antilogarithm2));
+      T value =
+        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
       loss[i] = value;
     }
   } else {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-      T antilogarithm = max(input_x[i], epsilon);
-      T antilogarithm2 = min(1 - input_x[i], 1 - epsilon);
-      T value = -weight[i] * (input_y[i] * logf(antilogarithm) + (1 - input_y[i]) * logf(antilogarithm2));
+      T value =
+        -weight[i] * (input_y[i] * logf(input_x[i] + epsilon) + (1 - input_y[i]) * logf(1 - input_x[i] + epsilon));
       tmp_loss[i] = value;
     }
   }
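The kernel rewrite above changes how log(0) is avoided: instead of clamping the prediction into [epsilon, 1 - epsilon] with max/min and epsilon = 1e-6, a much smaller epsilon (1e-12) is added directly inside logf, which leaves well-scaled inputs essentially untouched. A standalone sketch of the new per-element formula (the helper name and host-side std::log are illustrative; the kernel itself uses logf on device):

#include <cmath>

// Per-element weighted binary cross-entropy with the epsilon-inside-log guard used above.
// x is the predicted probability in [0, 1], y the target, w the element weight.
inline float WeightedBceElement(float x, float y, float w) {
  const float epsilon = 1e-12f;
  return -w * (y * std::log(x + epsilon) + (1.0f - y) * std::log(1.0f - x + epsilon));
}

At x = 0 or x = 1 the guarded term evaluates to log(1e-12), roughly -27.6, instead of -inf, so the reduced loss stays finite.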
@@ -139,12 +132,8 @@ __global__ void BinaryCrossEntropyLossKernel(const int input_size, const int red
 
 template <typename T>
 void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
-                            const T *weight, T *loss, cudaStream_t stream) {
+                            const T *weight, T *loss, T *tmp_loss, cudaStream_t stream) {
   LossInitKernel<<<1, 1, 0, stream>>>(loss);
-  T *tmp_loss;
-  if (reduction != 0) {
-    cudaMalloc(reinterpret_cast<void **>(&tmp_loss), input_size * sizeof(T));
-  }
   BinaryCrossEntropyLossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, reduction, input_x,
                                                                                    input_y, weight, loss, tmp_loss);
   if (reduction != 0) {
@@ -159,13 +148,12 @@ void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T
     }
     Copy<<<1, 1, 0, stream>>>(loss, tmp_loss, reduction, input_size);
   }
-  cudaFree(tmp_loss);
 }
 
 template <typename T>
 __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int reduction, const T *input_x,
                                                  const T *input_y, const T *weight, const T *dloss, T *dx) {
-  T epsilon = 1e-6;
+  T epsilon = 1e-12;
   if (reduction == 0) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
       T denominator = max(input_x[i] * (1 - input_x[i]), epsilon);
@@ -193,13 +181,14 @@ void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, con
 }
 
 template void KLDivLoss(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
-                        float *loss, cudaStream_t stream);
+                        float *loss, float *tmp_loss, cudaStream_t stream);
 template void KLDivLossGrad(const int &input_size, const int &reduction, const float *input_x, const float *input_y,
                             const float *dloss, float *dx, float *dy, cudaStream_t stream);
 template void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const float *input_x,
-                                     const float *input_y, const float *weight, float *loss, cudaStream_t stream);
+                                     const float *input_y, const float *weight, float *loss, float *tmp_loss,
+                                     cudaStream_t stream);
 template void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const float *input_x,
                                          const float *input_y, const float *weight, const float *dloss, float *dx,
@@ -18,12 +18,12 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_LOSS_WITH_REDUCTION_IMPL_CUH
 template <typename T>
 void BinaryCrossEntropyLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
-                            const T *weight, T *loss, cudaStream_t stream);
+                            const T *weight, T *loss, T *tmp_loss, cudaStream_t stream);
 template <typename T>
 void BinaryCrossEntropyLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y,
                                 const T *weight, const T *dloss, T *dx, cudaStream_t stream);
 template <typename T>
-void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss,
+void KLDivLoss(const int &input_size, const int &reduction, const T *input_x, const T *input_y, T *loss, T *tmp_loss,
                cudaStream_t stream);
 template <typename T>
 void KLDivLossGrad(const int &input_size, const int &reduction, const T *input_x, const T *input_y, const T *dloss,
@@ -30,19 +30,17 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
  public:
   BinaryCrossEntropyGpuKernel() : input_size_(1), reduction_(1) {}
   ~BinaryCrossEntropyGpuKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input_x = GetDeviceAddress<T>(inputs, 0);
     T *input_y = GetDeviceAddress<T>(inputs, 1);
     T *weight = GetDeviceAddress<T>(inputs, 2);
     T *loss = GetDeviceAddress<T>(outputs, 0);
-    BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss,
+    T *tmp_loss = GetDeviceAddress<T>(workspace, 0);
+    BinaryCrossEntropyLoss(input_size_, reduction_, input_x, input_y, weight, loss, tmp_loss,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -52,13 +50,16 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
     for (size_t i = 0; i < input_shape.size(); i++) {
       input_size_ *= input_shape[i];
     }
     string reduction = GetAttr<string>(kernel_node, "reduction");
     if (reduction == "none") {
       reduction_ = 0;
     } else if (reduction == "sum") {
       reduction_ = 2;
     }
+    workspace_size_ = sizeof(T);
+    if (reduction_ == 0) {
+      workspace_size_ *= input_size_;
+    }
     InitSizeLists();
     return true;
   }
@@ -73,12 +74,13 @@ class BinaryCrossEntropyGpuKernel : public GpuKernel {
     } else {
       output_size_list_.push_back(sizeof(T));
     }
+    workspace_size_list_.push_back(workspace_size_);
   }
 
  private:
   size_t input_size_;
   int reduction_;
+  size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
@@ -35,13 +35,13 @@ class KLDivLossGpuKernel : public GpuKernel {
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input_x = GetDeviceAddress<T>(inputs, 0);
     T *input_y = GetDeviceAddress<T>(inputs, 1);
     T *loss = GetDeviceAddress<T>(outputs, 0);
-    KLDivLoss(input_size_, reduction_, input_x, input_y, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
+    T *tmp_loss = GetDeviceAddress<T>(workspace, 0);
+    KLDivLoss(input_size_, reduction_, input_x, input_y, loss, tmp_loss, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -50,13 +50,16 @@ class KLDivLossGpuKernel : public GpuKernel {
     for (size_t i = 0; i < input_shape.size(); i++) {
       input_size_ *= input_shape[i];
     }
     string reduction = GetAttr<string>(kernel_node, "reduction");
     if (reduction == "none") {
       reduction_ = 0;
     } else if (reduction == "sum") {
       reduction_ = 2;
     }
+    workspace_size_ = sizeof(T);
+    if (reduction_ == 0) {
+      workspace_size_ *= input_size_;
+    }
     InitSizeLists();
     return true;
   }
@@ -70,12 +73,13 @@ class KLDivLossGpuKernel : public GpuKernel {
     } else {
       output_size_list_.push_back(sizeof(T));
     }
+    workspace_size_list_.push_back(workspace_size_);
   }
 
  private:
   size_t input_size_;
   int reduction_;
+  size_t workspace_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
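Both kernel classes gain the same workspace plumbing: Init computes workspace_size_, InitSizeLists publishes it through workspace_size_list_, the framework allocates a device buffer of that size before Launch, and Launch retrieves it with GetDeviceAddress<T>(workspace, 0) and forwards it as tmp_loss. A simplified, self-contained sketch of that flow (this is not the real GpuKernel base class, and it sizes the scratch buffer per element unconditionally only to keep the illustration short):

#include <cstddef>
#include <vector>

// Illustrative stand-in for the GpuKernel workspace contract: the kernel declares how many
// scratch bytes it needs, the framework allocates them, and Launch reads the pointer back.
struct WorkspaceContractSketch {
  void Init(size_t input_size) {
    input_size_ = input_size;
    workspace_size_ = sizeof(float) * input_size_;    // one scratch element per input element
    workspace_size_list_.push_back(workspace_size_);  // what the framework consults when allocating
  }
  // Plays the role of GetDeviceAddress<T>(workspace, 0) inside Launch.
  float *ScratchPtr(const std::vector<void *> &workspace) const {
    return static_cast<float *>(workspace.at(0));
  }

  size_t input_size_ = 1;
  size_t workspace_size_ = 0;
  std::vector<size_t> workspace_size_list_;
};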
@@ -72,6 +72,7 @@ class Categorical(Distribution):
                  dtype=mstype.int32,
                  name="Categorical"):
         param = dict(locals())
+        param['param_dict'] = {'probs': probs, 'logits': logits}
         valid_dtype = mstype.int_type
         check_type(dtype, valid_dtype, "Categorical")
         super(Categorical, self).__init__(seed, dtype, name, param)