Merge pull request !7693 from wanyiming/add_sigmoid_cross_entropy_with_logit_cputags/v1.1.0
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/sigmoid_cross_entropy_with_logits_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void SigmoidCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| std::vector<uint64_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (const uint64_t &d : x_shape) { | |||
| tensor_size_ *= d; | |||
| } | |||
| } | |||
| bool SigmoidCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void SigmoidCrossEntropyWithLogitsCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| auto logits_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto labels_addr = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| T zero = (T)0.0; | |||
| T one = (T)1.0; | |||
| T two = (T)2.0; | |||
| for (uint64_t i = 0; i < tensor_size_; ++i) { | |||
| if (logits_addr[i] >= zero) { | |||
| output_addr[i] = log1p(exp(logits_addr[i] - two * logits_addr[i])) - logits_addr[i] * (labels_addr[i] - one); | |||
| } else { | |||
| output_addr[i] = log1p(exp(logits_addr[i])) - logits_addr[i] * labels_addr[i]; | |||
| } | |||
| } | |||
| } | |||
| void SigmoidCrossEntropyWithLogitsCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(EXCEPTION) << "SigmoidCrossEntropyWithLogitsCPUKernel needs 2 inputs, but gets " << input_num; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "SigmoidCrossEntropyWithLogitsCPUKernel expects 1 output, but gets" << output_num; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class SigmoidCrossEntropyWithLogitsCPUKernel : public CPUKernel { | |||
| public: | |||
| SigmoidCrossEntropyWithLogitsCPUKernel() = default; | |||
| ~SigmoidCrossEntropyWithLogitsCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| TypeId dtype_{kTypeUnknown}; | |||
| uint64_t tensor_size_{1}; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| SigmoidCrossEntropyWithLogits, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| SigmoidCrossEntropyWithLogitsCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| SigmoidCrossEntropyWithLogits, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SigmoidCrossEntropyWithLogitsCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/sigmoid_cross_entropy_with_logits_grad_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void SigmoidCrossEntropyWithLogitsGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| std::vector<uint64_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (const uint64_t &d : x_shape) { | |||
| tensor_size_ *= d; | |||
| } | |||
| } | |||
| bool SigmoidCrossEntropyWithLogitsGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void SigmoidCrossEntropyWithLogitsGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| auto logits_addr = reinterpret_cast<T *>(inputs[0]->addr); | |||
| auto labels_addr = reinterpret_cast<T *>(inputs[1]->addr); | |||
| auto dloss_addr = reinterpret_cast<T *>(inputs[2]->addr); | |||
| auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); | |||
| T zero = (T)0.0; | |||
| T one = (T)1.0; | |||
| for (uint64_t i = 0; i < tensor_size_; ++i) { | |||
| if (logits_addr[i] >= zero) { | |||
| output_addr[i] = (one / (one + exp(-logits_addr[i])) - labels_addr[i]) * dloss_addr[i]; | |||
| } else { | |||
| const T exp_val = exp(logits_addr[i]); | |||
| output_addr[i] = (exp_val / (one + exp_val) - labels_addr[i]) * dloss_addr[i]; | |||
| } | |||
| } | |||
| } | |||
| void SigmoidCrossEntropyWithLogitsGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 3) { | |||
| MS_LOG(EXCEPTION) << "SigmoidCrossEntropyWithLogitsCPUKernel needs 2 inputs, but gets " << input_num; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "SigmoidCrossEntropyWithLogitsCPUKernel expects 1 output, but gets" << output_num; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,63 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class SigmoidCrossEntropyWithLogitsGradCPUKernel : public CPUKernel { | |||
| public: | |||
| SigmoidCrossEntropyWithLogitsGradCPUKernel() = default; | |||
| ~SigmoidCrossEntropyWithLogitsGradCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| void CheckParam(const CNodePtr &kernel_node); | |||
| TypeId dtype_{kTypeUnknown}; | |||
| uint64_t tensor_size_{1}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(SigmoidCrossEntropyWithLogitsGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddInputAttr(kNumberTypeFloat16) | |||
| .AddOutputAttr(kNumberTypeFloat16), | |||
| SigmoidCrossEntropyWithLogitsGradCPUKernel); | |||
| MS_REG_CPU_KERNEL(SigmoidCrossEntropyWithLogitsGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| SigmoidCrossEntropyWithLogitsGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_CPU_KERNEL_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| class NetSigmoidCrossEntropyWithLogits(nn.Cell): | |||
| def __init__(self): | |||
| super(NetSigmoidCrossEntropyWithLogits, self).__init__() | |||
| self.sigmoid_cross_entropy_with_logits_grad = G.SigmoidCrossEntropyWithLogitsGrad() | |||
| def construct(self, logits, labels, dout): | |||
| return self.sigmoid_cross_entropy_with_logits_grad(logits, labels, dout) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_sigmoid_cross_entropy_with_logits(): | |||
| logits = Tensor(np.array([[1, 1, 2], | |||
| [1, 2, 1], | |||
| [2, 1, 1]]).astype(np.float32)) | |||
| labels = Tensor(np.array([[0, 0, 1], | |||
| [0, 1, 0], | |||
| [1, 0, 0]]).astype(np.float32)) | |||
| dout = Tensor(np.ones(shape=[3, 3]).astype(np.float32)) | |||
| expect = np.array([[0.731059, 0.731059, -0.119203], | |||
| [0.731059, -0.119203, 0.731059], | |||
| [-0.119203, 0.731059, 0.731059]]).astype(np.float32) | |||
| error = np.ones(shape=[3, 3]) * 1.0e-6 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() | |||
| output = sigmoid_cross_entropy_with_logits(logits, labels, dout) | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(abs(diff) < error) | |||
| @@ -0,0 +1,54 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| class NetSigmoidCrossEntropyWithLogits(nn.Cell): | |||
| def __init__(self): | |||
| super(NetSigmoidCrossEntropyWithLogits, self).__init__() | |||
| self.loss = P.SigmoidCrossEntropyWithLogits() | |||
| def construct(self, logits, labels): | |||
| return self.loss(logits, labels) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_sigmoid_cross_entropy_with_logits(): | |||
| logits = Tensor(np.array([[1, 1, 2], | |||
| [1, 2, 1], | |||
| [2, 1, 1]]).astype(np.float32)) | |||
| labels = Tensor(np.array([[0, 0, 1], | |||
| [0, 1, 0], | |||
| [1, 0, 0]]).astype(np.float32)) | |||
| expect_loss = np.array([[1.313262, 1.313262, 0.126928], | |||
| [1.313262, 0.126928, 1.313262], | |||
| [0.126928, 1.313262, 1.313262]]).astype(np.float32) | |||
| error = np.ones(shape=[3, 3]) * 1.0e-6 | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| sigmoid_cross_entropy_with_logits = NetSigmoidCrossEntropyWithLogits() | |||
| output = sigmoid_cross_entropy_with_logits(logits, labels) | |||
| diff = output.asnumpy() - expect_loss | |||
| assert np.all(abs(diff) < error) | |||