| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" | |||||
| template <typename T, typename S> | |||||
| __global__ void SigmoidCrossEntropyWithLogitsKernel(const size_t size, const T *logits, const S *labels, T *outputs) { | |||||
| for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { | |||||
| const T reverse_factor = static_cast<T>(logits[i] >= 0); | |||||
| outputs[i] = log1p(exp(logits[i] - 2 * reverse_factor * logits[i])) - logits[i] * (labels[i] - reverse_factor); | |||||
| } | |||||
| } | |||||
| template <typename T, typename S> | |||||
| void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs, | |||||
| cudaStream_t cuda_stream) { | |||||
| SigmoidCrossEntropyWithLogitsKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, logits, labels, outputs); | |||||
| } | |||||
| template void SigmoidCrossEntropyWithLogits<float, float>(const size_t size, const float *logits, const float *labels, | |||||
| float *outputs, cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ | |||||
| #include "device/gpu/cuda_common.h" | |||||
| template <typename T, typename S> | |||||
| void SigmoidCrossEntropyWithLogits(const size_t size, const T *logits, const S *labels, T *outputs, | |||||
| cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_IMPL_H_ | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "kernel/gpu/nn/sigmoid_cross_entropy_with_logits_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_TWO( | |||||
| SigmoidCrossEntropyWithLogits, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| SigmoidCrossEntropyWithLogitsGpuKernel, float, float) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,97 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "kernel/gpu/gpu_kernel.h" | |||||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||||
| #include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T, typename S> | |||||
| class SigmoidCrossEntropyWithLogitsGpuKernel : public GpuKernel { | |||||
| public: | |||||
| SigmoidCrossEntropyWithLogitsGpuKernel() : logits_size_(0), labels_size_(0), outputs_size_(0) {} | |||||
| ~SigmoidCrossEntropyWithLogitsGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| T *logits_addr = GetDeviceAddress<T>(inputs, 0); | |||||
| S *labels_addr = GetDeviceAddress<S>(inputs, 1); | |||||
| T *outputs_addr = GetDeviceAddress<T>(outputs, 0); | |||||
| SigmoidCrossEntropyWithLogits(inputs[0]->size / sizeof(T), logits_addr, labels_addr, outputs_addr, | |||||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 2) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but SigmoidCrossEntropyWithLogits needs 2 inputs."; | |||||
| return false; | |||||
| } | |||||
| logits_size_ = sizeof(T); | |||||
| labels_size_ = sizeof(S); | |||||
| outputs_size_ = sizeof(T); | |||||
| auto logits_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < logits_shape.size(); i++) { | |||||
| logits_size_ *= logits_shape[i]; | |||||
| } | |||||
| auto labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| for (size_t i = 0; i < labels_shape.size(); i++) { | |||||
| labels_size_ *= labels_shape[i]; | |||||
| } | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| for (size_t i = 0; i < output_shape.size(); i++) { | |||||
| outputs_size_ *= output_shape[i]; | |||||
| } | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(logits_size_); | |||||
| input_size_list_.push_back(labels_size_); | |||||
| output_size_list_.push_back(outputs_size_); | |||||
| } | |||||
| private: | |||||
| size_t logits_size_; | |||||
| size_t labels_size_; | |||||
| size_t outputs_size_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GPU_KERNEL_H_ | |||||
| @@ -0,0 +1,60 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| class NetSigmoidCrossEntropyWithLogits(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetSigmoidCrossEntropyWithLogits, self).__init__() | |||||
| self.loss = P.SigmoidCrossEntropyWithLogits() | |||||
| def construct(self, logits, labels): | |||||
| return self.loss(logits, labels) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_sigmoid_cross_entropy_with_logits(): | |||||
| logits = Tensor(np.array([[1, 1, 2], | |||||
| [1, 2, 1], | |||||
| [2, 1, 1]]).astype(np.float32)) | |||||
| labels = Tensor(np.array([[0, 0, 1], | |||||
| [0, 1, 0], | |||||
| [1, 0, 0]]).astype(np.float32)) | |||||
| expect_loss = np.array([[1.313262, 1.313262, 0.126928], | |||||
| [1.313262, 0.126928, 1.313262], | |||||
| [0.126928, 1.313262, 1.313262]]).astype(np.float32) | |||||
| error = np.ones(shape=[3, 3]) * 1.0e-6 | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='GPU') | |||||
| sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() | |||||
| output = sigmoid_cross_entropy_with_logits(logits, labels) | |||||
| diff = output.asnumpy() - expect_loss | |||||
| assert np.all(abs(diff) < error) | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') | |||||
| sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() | |||||
| output = sigmoid_cross_entropy_with_logits(logits, labels) | |||||
| diff = output.asnumpy() - expect_loss | |||||
| assert np.all(abs(diff) < error) | |||||