From: @TFbunny Reviewed-by: @tom__chen,@robingrosman Signed-off-by: @robingrosmanpull/14508/MERGE
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -21,10 +21,10 @@ __global__ void SigmoidCrossEntropyWithLogitsGradKernel(const size_t size, const | |||||
| const T *dout_addr, T *outputs) { | const T *dout_addr, T *outputs) { | ||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | ||||
| if (logits[i] >= 0) { | if (logits[i] >= 0) { | ||||
| outputs[i] = (1. / (1. + exp(-logits[i])) - labels[i]) * dout_addr[i]; | |||||
| outputs[i] = (static_cast<T>(1.) / (static_cast<T>(1.) + exp(-logits[i])) - labels[i]) * dout_addr[i]; | |||||
| } else { | } else { | ||||
| const T exp_val = exp(logits[i]); | const T exp_val = exp(logits[i]); | ||||
| outputs[i] = (exp_val / (1. + exp_val) - labels[i]) * dout_addr[i]; | |||||
| outputs[i] = (exp_val / (static_cast<T>(1.) + exp_val) - labels[i]) * dout_addr[i]; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -39,3 +39,6 @@ void SigmoidCrossEntropyWithLogitsGrad(const size_t size, const T *logits, const | |||||
| template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits, | template void SigmoidCrossEntropyWithLogitsGrad<float, float>(const size_t size, const float *logits, | ||||
| const float *labels, const float *dout_addr, | const float *labels, const float *dout_addr, | ||||
| float *outputs, cudaStream_t cuda_stream); | float *outputs, cudaStream_t cuda_stream); | ||||
| template void SigmoidCrossEntropyWithLogitsGrad<double, double>(const size_t size, const double *logits, | |||||
| const double *labels, const double *dout_addr, | |||||
| double *outputs, cudaStream_t cuda_stream); | |||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -20,7 +20,8 @@ template <typename T, typename S> | |||||
| __global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) { | __global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) { | ||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | ||||
| const T reverse_factor = static_cast<T>(logits[i] >= 0); | const T reverse_factor = static_cast<T>(logits[i] >= 0); | ||||
| outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor); | |||||
| outputs[i] = | |||||
| log1p(exp(logits[i] - static_cast<T>(2) * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor); | |||||
| } | } | ||||
| } | } | ||||
| @@ -32,3 +33,6 @@ void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S * | |||||
| template void SigmoidCrossEntropyWithLogits<float, float>(const size_t size, const float *logits, const float *labels, | template void SigmoidCrossEntropyWithLogits<float, float>(const size_t size, const float *logits, const float *labels, | ||||
| float *outputs, cudaStream_t cuda_stream); | float *outputs, cudaStream_t cuda_stream); | ||||
| template void SigmoidCrossEntropyWithLogits<double, double>(const size_t size, const double *logits, | |||||
| const double *labels, double *outputs, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO( | |||||
| SigmoidCrossEntropyWithLogits, | SigmoidCrossEntropyWithLogits, | ||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| SigmoidCrossEntropyWithLogitsGpuKernel, float, float) | SigmoidCrossEntropyWithLogitsGpuKernel, float, float) | ||||
| MS_REG_GPU_KERNEL_TWO( | |||||
| SigmoidCrossEntropyWithLogits, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||||
| SigmoidCrossEntropyWithLogitsGpuKernel, double, double) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -25,5 +25,12 @@ MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad, | |||||
| .AddInputAttr(kNumberTypeFloat32) | .AddInputAttr(kNumberTypeFloat32) | ||||
| .AddOutputAttr(kNumberTypeFloat32), | .AddOutputAttr(kNumberTypeFloat32), | ||||
| SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float) | SigmoidCrossEntropyWithLogitsGradGpuKernel, float, float) | ||||
| MS_REG_GPU_KERNEL_TWO(SigmoidCrossEntropyWithLogitsGrad, | |||||
| KernelAttr() | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddInputAttr(kNumberTypeFloat64) | |||||
| .AddOutputAttr(kNumberTypeFloat64), | |||||
| SigmoidCrossEntropyWithLogitsGradGpuKernel, double, double) | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -31,32 +31,43 @@ class NetSigmoidCrossEntropyWithLogits(nn.Cell): | |||||
| return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout) | return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_sigmoid_cross_entropy_with_logits(): | |||||
| def sigmoid_cross_entropy_with_logits_grad(nptype): | |||||
| logits = Tensor(np.array([[1, 1, 2], | logits = Tensor(np.array([[1, 1, 2], | ||||
| [1, 2, 1], | [1, 2, 1], | ||||
| [2, 1, 1]]).astype(np.float32)) | |||||
| [2, 1, 1]]).astype(nptype)) | |||||
| labels = Tensor(np.array([[0, 0, 1], | labels = Tensor(np.array([[0, 0, 1], | ||||
| [0, 1, 0], | [0, 1, 0], | ||||
| [1, 0, 0]]).astype(np.float32)) | |||||
| dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32)) | |||||
| [1, 0, 0]]).astype(nptype)) | |||||
| dout = Tensor(np.ones(shape=[3, 3]).astype(nptype)) | |||||
| expect = np.array([[0.731059, 0.731059, -0.119203], | expect = np.array([[0.731059, 0.731059, -0.119203], | ||||
| [0.731059, -0.119203, 0.731059], | [0.731059, -0.119203, 0.731059], | ||||
| [-0.119203, 0.731059, 0.731059]]).astype(np.float32) | |||||
| [-0.119203, 0.731059, 0.731059]]).astype(nptype) | |||||
| error = np.ones(shape=[3, 3]) * 1.0e-6 | error = np.ones(shape=[3, 3]) * 1.0e-6 | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | ||||
| sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() | |||||
| output = sigmoid_cross_entropy_with_logits(logits, labels, dout) | |||||
| net = NetSigmoidCrossEntropyWithLogits() | |||||
| output = net(logits, labels, dout) | |||||
| diff = output.asnumpy() - expect | diff = output.asnumpy() - expect | ||||
| assert np.all(abs(diff) < error) | assert np.all(abs(diff) < error) | ||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | ||||
| sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() | |||||
| output = sigmoid_cross_entropy_with_logits(logits, labels, dout) | |||||
| net = NetSigmoidCrossEntropyWithLogits() | |||||
| output = net(logits, labels, dout) | |||||
| diff = output.asnumpy() - expect | diff = output.asnumpy() - expect | ||||
| assert np.all(abs(diff) < error) | assert np.all(abs(diff) < error) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_sigmoid_cross_entropy_with_logits_float32(): | |||||
| sigmoid_cross_entropy_with_logits_grad(np.float32) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_sigmoid_cross_entropy_with_logits_float64(): | |||||
| sigmoid_cross_entropy_with_logits_grad(np.float64) | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -31,30 +31,41 @@ class NetSigmoidCrossEntropyWithLogits(nn.Cell): | |||||
| return self.loss(logits, labels) | return self.loss(logits, labels) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_sigmoid_cross_entropy_with_logits(): | |||||
| def sigmoid_cross_entropy_with_logits(nptype): | |||||
| logits = Tensor(np.array([[1, 1, 2], | logits = Tensor(np.array([[1, 1, 2], | ||||
| [1, 2, 1], | [1, 2, 1], | ||||
| [2, 1, 1]]).astype(np.float32)) | |||||
| [2, 1, 1]]).astype(nptype)) | |||||
| labels = Tensor(np.array([[0, 0, 1], | labels = Tensor(np.array([[0, 0, 1], | ||||
| [0, 1, 0], | [0, 1, 0], | ||||
| [1, 0, 0]]).astype(np.float32)) | |||||
| [1, 0, 0]]).astype(nptype)) | |||||
| expect_loss = np.array([[1.313262, 1.313262, 0.126928], | expect_loss = np.array([[1.313262, 1.313262, 0.126928], | ||||
| [1.313262, 0.126928, 1.313262], | [1.313262, 0.126928, 1.313262], | ||||
| [0.126928, 1.313262, 1.313262]]).astype(np.float32) | |||||
| [0.126928, 1.313262, 1.313262]]).astype(nptype) | |||||
| error = np.ones(shape=[3, 3]) * 1.0e-6 | error = np.ones(shape=[3, 3]) * 1.0e-6 | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | ||||
| sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() | |||||
| output = sigmoid_cross_entropy_with_logits(logits, labels) | |||||
| net = NetSigmoidCrossEntropyWithLogits() | |||||
| output = net(logits, labels) | |||||
| diff = output.asnumpy() - expect_loss | diff = output.asnumpy() - expect_loss | ||||
| assert np.all(abs(diff) < error) | assert np.all(abs(diff) < error) | ||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | ||||
| sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() | |||||
| output = sigmoid_cross_entropy_with_logits(logits, labels) | |||||
| net = NetSigmoidCrossEntropyWithLogits() | |||||
| output = net(logits, labels) | |||||
| diff = output.asnumpy() - expect_loss | diff = output.asnumpy() - expect_loss | ||||
| assert np.all(abs(diff) < error) | assert np.all(abs(diff) < error) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_sigmoid_cross_entropy_with_logits_float32(): | |||||
| sigmoid_cross_entropy_with_logits(np.float32) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_sigmoid_cross_entropy_with_logits_float64(): | |||||
| sigmoid_cross_entropy_with_logits(np.float64) | |||||