From: @he-botao Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include <string> | |||
| #include <thread> | |||
| #include "backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| void EluGradCPUKernel::EluGrad(const T *input0, const T *input1, T *out, size_t start, size_t end) { | |||
| const T alpha = static_cast<T>(1); | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = (input1[i] < static_cast<T>(0)) ? input0[i] * (input1[i] + alpha) : input0[i]; | |||
| } | |||
| } | |||
| void EluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) { | |||
| MS_LOG(EXCEPTION) << "Input0 and input1 must has the same data type"; | |||
| } | |||
| } | |||
| bool EluGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat16) { | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| void EluGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||
| T *input0 = reinterpret_cast<T *>(inputs[0]->addr); | |||
| T *input1 = reinterpret_cast<T *>(inputs[1]->addr); | |||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | |||
| size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | |||
| auto max_thread_num = std::thread::hardware_concurrency(); | |||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | |||
| MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||
| std::vector<std::thread> threads; | |||
| if (thread_num < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; | |||
| return; | |||
| } | |||
| threads.reserve(thread_num); | |||
| size_t start = 0; | |||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | |||
| if (once_compute_size < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | |||
| return; | |||
| } | |||
| while (start < lens) { | |||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||
| threads.emplace_back(std::thread(&EluGradCPUKernel::EluGrad<T>, this, input0, input1, output, start, end)); | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| threads[i].join(); | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class EluGradCPUKernel : public CPUKernel { | |||
| public: | |||
| EluGradCPUKernel() = default; | |||
| ~EluGradCPUKernel() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| template <typename T> | |||
| void EluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| TypeId dtype_{kTypeUnknown}; | |||
| }; | |||
| MS_REG_CPU_KERNEL( | |||
| EluGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| EluGradCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| EluGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| EluGradCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| EluGrad, KernelAttr().AddInputAttr(kNumberTypeFloat).AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeFloat), | |||
| EluGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_ | |||
| @@ -555,7 +555,7 @@ class Elu(PrimitiveWithInfer): | |||
| Tensor, has the same shape and data type as `input_x`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32) | |||
| @@ -0,0 +1,75 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| class NetEluGrad(nn.Cell): | |||
| def __init__(self): | |||
| super(NetEluGrad, self).__init__() | |||
| self.elu_grad = G.EluGrad() | |||
| def construct(self, dy, y): | |||
| return self.elu_grad(dy, y) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_elu_grad_fp32(): | |||
| y = Tensor(np.array([[[[-0.3, 1, 2], | |||
| [1, -0.6, 1], | |||
| [2, 1, -2]]]]).astype(np.float32)) | |||
| dy = Tensor(np.array([[[[-11, 2, 4], | |||
| [-1, 1, -1], | |||
| [-4, 4, -4]]]]).astype(np.float32)) | |||
| expect = np.array([[[[-7.7, 2, 4], | |||
| [-1, 0.4, -1], | |||
| [-4, 4, 4]]]]).astype(np.float32) | |||
| error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6 | |||
| elu_grad = NetEluGrad() | |||
| output = elu_grad(dy, y) | |||
| print(output) | |||
| diff = np.abs(output.asnumpy() - expect) | |||
| double_check = diff / expect | |||
| assert np.all(double_check < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_elu_grad_fp16(): | |||
| y = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float16)) | |||
| dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float16)) | |||
| expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float16) | |||
| error = np.ones(shape=[2, 3]) * 1.0e-3 | |||
| elu_grad = NetEluGrad() | |||
| output = elu_grad(dy, y) | |||
| print(output) | |||
| diff = np.abs(output.asnumpy() - expect) | |||
| double_check = diff / expect | |||
| assert np.all(double_check < error) | |||