From: @he-botao (tags/v1.2.0-rc1)
@@ -0,0 +1,85 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cmath>
#include <string>
#include <thread>
#include "backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
template <typename T>
void EluGradCPUKernel::EluGrad(const T *input0, const T *input1, T *out, size_t start, size_t end) {
  const T alpha = static_cast<T>(1);
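  // input1 is the forward ELU output y; where y < 0, d(elu)/dx = y + alpha, otherwise the gradient passes through.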
  for (size_t i = start; i < end; i++) {
    out[i] = (input1[i] < static_cast<T>(0)) ? input0[i] * (input1[i] + alpha) : input0[i];
  }
}

void EluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
  if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) {
| MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; | |||||
  }
}

bool EluGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                              const std::vector<kernel::AddressPtr> & /*workspace*/,
                              const std::vector<kernel::AddressPtr> &outputs) {
  if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) {
    LaunchKernel<float>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, outputs);
  } else {
| MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; | |||||
  }
  return true;
}

template <typename T>
void EluGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
  T *input0 = reinterpret_cast<T *>(inputs[0]->addr);
  T *input1 = reinterpret_cast<T *>(inputs[1]->addr);
  T *output = reinterpret_cast<T *>(outputs[0]->addr);
  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
  auto max_thread_num = std::thread::hardware_concurrency();
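  // Use roughly one thread per 128 elements, capped at the hardware thread count.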
  size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
  MS_LOG(INFO) << "lens=" << lens << "; thread_num=" << thread_num << "; max_thread_num=" << max_thread_num;
  std::vector<std::thread> threads;
  if (thread_num < 1) {
    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
    return;
  }
  threads.reserve(thread_num);
  size_t start = 0;
  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
  if (once_compute_size < 1) {
    MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
    return;
  }
  while (start < lens) {
    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
    threads.emplace_back(&EluGradCPUKernel::EluGrad<T>, this, input0, input1, output, start, end);
    start += once_compute_size;
  }
  for (size_t i = 0; i < threads.size(); ++i) {
    threads[i].join();
  }
}
}  // namespace kernel
}  // namespace mindspore
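For reference, the elementwise rule this kernel parallelizes can be sketched in NumPy. This is a hypothetical helper (elu_grad_ref is not part of MindSpore), assuming alpha is fixed at 1 as in the C++ above, with input1 being the forward ELU output y:

    import numpy as np

    def elu_grad_ref(dy, y, alpha=1.0):
        # Where the forward output y is negative, d(elu)/dx = y + alpha; elsewhere dy passes through.
        return np.where(y < 0, dy * (y + alpha), dy)

    y = np.array([-0.3, 1.0, -2.0], dtype=np.float32)
    dy = np.array([-11.0, 2.0, -4.0], dtype=np.float32)
    print(elu_grad_ref(dy, y))  # [-7.7  2.   4. ]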
@@ -0,0 +1,56 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
#include <memory>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
class EluGradCPUKernel : public CPUKernel {
 public:
  EluGradCPUKernel() = default;
  ~EluGradCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 private:
  template <typename T>
  void EluGrad(const T *input0, const T *input1, T *out, size_t start, size_t end);
  TypeId dtype_{kTypeUnknown};
};

MS_REG_CPU_KERNEL(
  EluGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  EluGradCPUKernel);
MS_REG_CPU_KERNEL(
  EluGrad,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  EluGradCPUKernel);
MS_REG_CPU_KERNEL(
  EluGrad, KernelAttr().AddInputAttr(kNumberTypeFloat).AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeFloat),
  EluGradCPUKernel);
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
@@ -555,7 +555,7 @@ class Elu(PrimitiveWithInfer):
         Tensor, has the same shape and data type as `input_x`.

     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
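Since the hunk above adds ``CPU`` to the supported platforms, the primitive can now be exercised on a CPU target. A minimal sketch, assuming the standard ops.Elu primitive and the docstring's example input (output omitted rather than guessed):

    import numpy as np
    import mindspore
    import mindspore.context as context
    import mindspore.ops.operations as P
    from mindspore import Tensor

    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    elu = P.Elu()
    input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
    print(elu(input_x))  # elu(x) = x for x > 0, exp(x) - 1 otherwise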
@@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

class NetEluGrad(nn.Cell):
    def __init__(self):
        super(NetEluGrad, self).__init__()
        self.elu_grad = G.EluGrad()

    def construct(self, dy, y):
        return self.elu_grad(dy, y)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_elu_grad_fp32():
    y = Tensor(np.array([[[[-0.3, 1, 2],
                           [1, -0.6, 1],
                           [2, 1, -2]]]]).astype(np.float32))
    dy = Tensor(np.array([[[[-11, 2, 4],
                            [-1, 1, -1],
                            [-4, 4, -4]]]]).astype(np.float32))
    expect = np.array([[[[-7.7, 2, 4],
                         [-1, 0.4, -1],
                         [-4, 4, 4]]]]).astype(np.float32)
    error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6
    elu_grad = NetEluGrad()
    output = elu_grad(dy, y)
    print(output)
    diff = np.abs(output.asnumpy() - expect)
    # Compare against |expect| so negative entries cannot pass the check trivially.
    double_check = diff / np.abs(expect)
    assert np.all(double_check < error)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_elu_grad_fp16():
    y = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float16))
    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float16))
    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float16)
    error = np.ones(shape=[2, 3]) * 1.0e-3
    elu_grad = NetEluGrad()
    output = elu_grad(dy, y)
    print(output)
    diff = np.abs(output.asnumpy() - expect)
    # Compare against |expect| so negative entries cannot pass the check trivially.
    double_check = diff / np.abs(expect)
    assert np.all(double_check < error)
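The fp32 expectations above follow directly from the kernel's elementwise rule; a short NumPy cross-check (same formula as the C++ kernel, with alpha = 1):

    import numpy as np

    y = np.array([[[[-0.3, 1, 2], [1, -0.6, 1], [2, 1, -2]]]], dtype=np.float32)
    dy = np.array([[[[-11, 2, 4], [-1, 1, -1], [-4, 4, -4]]]], dtype=np.float32)
    # dy * (y + 1) where y < 0, dy elsewhere -- reproduces the `expect` array:
    # [[[[-7.7, 2, 4], [-1, 0.4, -1], [-4, 4, 4]]]]
    print(np.where(y < 0, dy * (y + 1), dy))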