From: @TFbunny
Reviewed-by: @tom__chen, @robingrosman
Signed-off-by: @robingrosman
pull/14508/MERGE
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,10 +21,10 @@ __global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const
                                                          const T *dout_addr, T *outputs) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     if (logits[i] >= 0) {
-      outputs[i] = (1. / (1. + exp(-logits[i])) - labels[i]) * dout_addr[i];
+      outputs[i] = (static_cast<T>(1.) / (static_cast<T>(1.) + exp(-logits[i])) - labels[i]) * dout_addr[i];
     } else {
       const T exp_val = exp(logits[i]);
-      outputs[i] = (exp_val / (1. + exp_val) - labels[i]) * dout_addr[i];
+      outputs[i] = (exp_val / (static_cast<T>(1.) + exp_val) - labels[i]) * dout_addr[i];
     }
   }
 }
@@ -39,3 +39,6 @@ void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const
 template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits,
                                                               const float *labels, const float *dout_addr,
                                                               float *outputs, cudaStream_t cuda_stream);
+template void SigmoidCrossEntropyWithLogitsGrad<double, double>(const size_t size, const double *logits,
+                                                                const double *labels, const double *dout_addr,
+                                                                double *outputs, cudaStream_t cuda_stream);
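
For reference, the branch in the gradient kernel above is a numerically stable evaluation of sigmoid(logits) - labels, scaled by the incoming gradient dout: for logits >= 0 it uses 1/(1+exp(-x)), otherwise exp(x)/(1+exp(x)), so the exp() argument is never positive and cannot overflow in either float or double. A minimal NumPy sketch of the same computation (the helper name sigmoid_ce_grad_reference is illustrative, not part of this patch):

import numpy as np

def sigmoid_ce_grad_reference(logits, labels, dout):
    # Mirror of the kernel's branch; np.where evaluates both sides, so clamp
    # the exponent argument in each branch to keep the unused side finite.
    pos = 1.0 / (1.0 + np.exp(-np.maximum(logits, 0.0)))   # branch for logits >= 0
    exp_val = np.exp(np.minimum(logits, 0.0))              # branch for logits < 0
    neg = exp_val / (1.0 + exp_val)
    return (np.where(logits >= 0, pos, neg) - labels) * dout

On the 3x3 inputs used in the test below this reproduces the expected values 0.731059 (= sigmoid(1) - 0) and -0.119203 (= sigmoid(2) - 1).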
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,7 +20,8 @@ template <typename T, typename S>
 __global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
     const T reverse_factor = static_cast<T>(logits[i] >= 0);
-    outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor);
+    outputs[i] =
+      log1p(exp(logits[i] - static_cast<T>(2) * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor);
   }
 }
@@ -32,3 +33,6 @@ void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *
 template void SigmoidCrossEntropyWithLogits<float, float>(const size_t size, const float *logits, const float *labels,
                                                           float *outputs, cudaStream_t cuda_stream);
+template void SigmoidCrossEntropyWithLogits<double, double>(const size_t size, const double *logits,
+                                                            const double *labels, double *outputs,
+                                                            cudaStream_t cuda_stream);
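
In the forward kernel above, reverse_factor is 1 for non-negative logits and 0 otherwise, so logits - 2 * reverse_factor * logits is always -|logits| and the exp() cannot overflow; the expression is the usual stable form max(x, 0) - x*y + log1p(exp(-|x|)) of the per-element loss -y*log(sigmoid(x)) - (1-y)*log(1-sigmoid(x)). A minimal NumPy sketch (sigmoid_ce_reference is an illustrative name, not part of this patch):

import numpy as np

def sigmoid_ce_reference(logits, labels):
    # Same expression as the kernel: exp() only ever sees -|logits|.
    reverse_factor = (logits >= 0).astype(logits.dtype)
    return (np.log1p(np.exp(logits - 2 * reverse_factor * logits))
            - logits * (labels - reverse_factor))

For a logit of 1 with label 0 this gives 1.313262, and for a logit of 2 with label 1 it gives 0.126928, the values hard-coded as expect_loss in the forward test below.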
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO(
   SigmoidCrossEntropyWithLogits,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   SigmoidCrossEntropyWithLogitsGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(
+  SigmoidCrossEntropyWithLogits,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  SigmoidCrossEntropyWithLogitsGpuKernel, double, double)
 }  // namespace kernel
 }  // namespace mindspore
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddInputAttr(kNumberTypeFloat64)
+                        .AddOutputAttr(kNumberTypeFloat64),
+                      SigmoidCrossEntropyWithLogitsGradGpuKernel, double, double)
 }  // namespace kernel
 }  // namespace mindspore
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,32 +31,43 @@ class NetSigmoidCrossEntropyWithLogits(nn.Cell):
         return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout)

-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_sigmoid_cross_entropy_with_logits():
+def sigmoid_cross_entropy_with_logits_grad(nptype):
     logits = Tensor(np.array([[1, 1, 2],
                               [1, 2, 1],
-                              [2, 1, 1]]).astype(np.float32))
+                              [2, 1, 1]]).astype(nptype))
     labels = Tensor(np.array([[0, 0, 1],
                               [0, 1, 0],
-                              [1, 0, 0]]).astype(np.float32))
-    dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32))
+                              [1, 0, 0]]).astype(nptype))
+    dout = Tensor(np.ones(shape=[3, 3]).astype(nptype))
     expect = np.array([[0.731059, 0.731059, -0.119203],
                        [0.731059, -0.119203, 0.731059],
-                       [-0.119203, 0.731059, 0.731059]]).astype(np.float32)
+                       [-0.119203, 0.731059, 0.731059]]).astype(nptype)
     error = np.ones(shape=[3, 3]) * 1.0e-6

     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
-    output = sigmoid_cross_entropy_with_logits(logits, labels, dout)
+    net = NetSigmoidCrossEntropyWithLogits()
+    output = net(logits, labels, dout)
     diff = output.asnumpy() - expect
     assert np.all(abs(diff) < error)

     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
-    output = sigmoid_cross_entropy_with_logits(logits, labels, dout)
+    net = NetSigmoidCrossEntropyWithLogits()
+    output = net(logits, labels, dout)
     diff = output.asnumpy() - expect
     assert np.all(abs(diff) < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits_float32():
+    sigmoid_cross_entropy_with_logits_grad(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits_float64():
+    sigmoid_cross_entropy_with_logits_grad(np.float64)
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,30 +31,41 @@ class NetSigmoidCrossEntropyWithLogits(nn.Cell):
         return self.loss(logits, labels)

-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_sigmoid_cross_entropy_with_logits():
+def sigmoid_cross_entropy_with_logits(nptype):
     logits = Tensor(np.array([[1, 1, 2],
                               [1, 2, 1],
-                              [2, 1, 1]]).astype(np.float32))
+                              [2, 1, 1]]).astype(nptype))
     labels = Tensor(np.array([[0, 0, 1],
                               [0, 1, 0],
-                              [1, 0, 0]]).astype(np.float32))
+                              [1, 0, 0]]).astype(nptype))
     expect_loss = np.array([[1.313262, 1.313262, 0.126928],
                             [1.313262, 0.126928, 1.313262],
-                            [0.126928, 1.313262, 1.313262]]).astype(np.float32)
+                            [0.126928, 1.313262, 1.313262]]).astype(nptype)
     error = np.ones(shape=[3, 3]) * 1.0e-6

     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
-    output = sigmoid_cross_entropy_with_logits(logits, labels)
+    net = NetSigmoidCrossEntropyWithLogits()
+    output = net(logits, labels)
     diff = output.asnumpy() - expect_loss
     assert np.all(abs(diff) < error)

     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits()
-    output = sigmoid_cross_entropy_with_logits(logits, labels)
+    net = NetSigmoidCrossEntropyWithLogits()
+    output = net(logits, labels)
     diff = output.asnumpy() - expect_loss
     assert np.all(abs(diff) < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits_float32():
+    sigmoid_cross_entropy_with_logits(np.float32)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sigmoid_cross_entropy_with_logits_float64():
+    sigmoid_cross_entropy_with_logits(np.float64)
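
With the float64 registrations in place, the op can also be exercised directly, outside the test Cell. A minimal usage sketch, assuming the same public APIs the tests above already rely on (mindspore.context, mindspore.Tensor, and ops.operations.SigmoidCrossEntropyWithLogits):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
loss = P.SigmoidCrossEntropyWithLogits()
logits = Tensor(np.array([[1.0, 2.0]]).astype(np.float64))
labels = Tensor(np.array([[0.0, 1.0]]).astype(np.float64))
print(loss(logits, labels))  # ~[[1.313262, 0.126928]], now computed in double precision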