diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.cc
new file mode 100644
index 0000000000..70476811d2
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <cmath>
+#include <thread>
+#include <vector>
+#include "backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h"
+#include "runtime/device/cpu/cpu_device_address.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+void EluGradCPUKernel::EluGrad(const T *input0, const T *input1, T *out, size_t start, size_t end) {
+  const T alpha = static_cast<T>(1);
+  for (size_t i = start; i < end; i++) {
+    out[i] = (input1[i] < static_cast<T>(0)) ? input0[i] * (input1[i] + alpha) : input0[i];
+  }
+}
+
+void EluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
+  if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) {
+    MS_LOG(EXCEPTION) << "Input0 and input1 must have the same data type";
+  }
+}
+
+bool EluGradCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
+                              const std::vector<AddressPtr> & /*workspace*/,
+                              const std::vector<AddressPtr> &outputs) {
+  if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) {
+    LaunchKernel<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat16) {
+    LaunchKernel<float16>(inputs, outputs);
+  } else {
+    MS_LOG(EXCEPTION) << "Data type " << TypeIdLabel(dtype_) << " is not supported.";
+  }
+  return true;
+}
+
+template <typename T>
+void EluGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
+  T *input0 = reinterpret_cast<T *>(inputs[0]->addr);
+  T *input1 = reinterpret_cast<T *>(inputs[1]->addr);
+  T *output = reinterpret_cast<T *>(outputs[0]->addr);
+
+  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
+  auto max_thread_num = std::thread::hardware_concurrency();
+  size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
+  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
+  std::vector<std::thread> threads;
+  if (thread_num < 1) {
+    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
+    return;
+  }
+  threads.reserve(thread_num);
+  size_t start = 0;
+  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
+  if (once_compute_size < 1) {
+    MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
+    return;
+  }
+  while (start < lens) {
+    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
+    threads.emplace_back(std::thread(&EluGradCPUKernel::EluGrad<T>, this, input0, input1, output, start, end));
+    start += once_compute_size;
+  }
+  for (size_t i = 0; i < threads.size(); ++i) {
+    threads[i].join();
+  }
+}
+}  // namespace kernel
+}  // namespace mindspore
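Note: the elementwise rule above relies on the identity elu'(x) = elu(x) + alpha for x < 0 (with alpha = 1 here), which is why the kernel needs only the forward output `y` and the incoming gradient `dy`, not the original input. A minimal NumPy sketch of the same computation, usable as a host-side reference; the name `elu_grad_ref` is ours, not part of the patch:

```python
import numpy as np

def elu_grad_ref(dy, y, alpha=1.0):
    # Where the forward output y is negative, d(elu)/dx = y + alpha; elsewhere it is 1.
    return np.where(y < 0, dy * (y + alpha), dy)
```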
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h
new file mode 100644
index 0000000000..07629f33ea
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/elu_grad_cpu_kernel.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
+#include <memory>
+#include <vector>
+#include "backend/kernel_compiler/cpu/cpu_kernel.h"
+#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+class EluGradCPUKernel : public CPUKernel {
+ public:
+  EluGradCPUKernel() = default;
+  ~EluGradCPUKernel() override = default;
+
+  void InitKernel(const CNodePtr &kernel_node) override;
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs) override;
+  template <typename T>
+  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
+ private:
+  template <typename T>
+  void EluGrad(const T *input0, const T *input1, T *out, size_t start, size_t end);
+  TypeId dtype_{kTypeUnknown};
+};
+
+MS_REG_CPU_KERNEL(
+  EluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  EluGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  EluGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  EluGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  EluGrad, KernelAttr().AddInputAttr(kNumberTypeFloat).AddInputAttr(kNumberTypeFloat).AddOutputAttr(kNumberTypeFloat),
+  EluGradCPUKernel);
+}  // namespace kernel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELU_GRAD_CPU_KERNEL_H_
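Note: `LaunchKernel` in the .cc file splits the flat output into contiguous chunks, targeting roughly 128 elements per thread, capping the thread count at `std::thread::hardware_concurrency()`, and sizing each chunk with a ceiling division so the last chunk absorbs the remainder. A small Python sketch of that splitting arithmetic, for illustration only (the helper name `plan_chunks` is hypothetical, not part of the patch):

```python
import math

def plan_chunks(lens, max_thread_num, per_thread=128):
    # Mirrors LaunchKernel: ceil(lens / 128) threads for small inputs, else max_thread_num.
    thread_num = math.ceil(lens / per_thread) if lens < per_thread * max_thread_num else max_thread_num
    once = (lens + thread_num - 1) // thread_num  # ceiling division: elements per chunk
    return [(s, min(s + once, lens)) for s in range(0, lens, once)]

print(plan_chunks(300, 8))  # [(0, 100), (100, 200), (200, 300)]
```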
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 60f140b328..80844be63d 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -555,7 +555,7 @@ class Elu(PrimitiveWithInfer):
         Tensor, has the same shape and data type as `input_x`.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
diff --git a/tests/st/ops/cpu/test_elu_grad_op.py b/tests/st/ops/cpu/test_elu_grad_op.py
new file mode 100644
index 0000000000..2aca961a9b
--- /dev/null
+++ b/tests/st/ops/cpu/test_elu_grad_op.py
@@ -0,0 +1,75 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
+
+
+class NetEluGrad(nn.Cell):
+    def __init__(self):
+        super(NetEluGrad, self).__init__()
+        self.elu_grad = G.EluGrad()
+
+    def construct(self, dy, y):
+        return self.elu_grad(dy, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_elu_grad_fp32():
+    y = Tensor(np.array([[[[-0.3, 1, 2],
+                           [1, -0.6, 1],
+                           [2, 1, -2]]]]).astype(np.float32))
+    dy = Tensor(np.array([[[[-11, 2, 4],
+                            [-1, 1, -1],
+                            [-4, 4, -4]]]]).astype(np.float32))
+
+    expect = np.array([[[[-7.7, 2, 4],
+                         [-1, 0.4, -1],
+                         [-4, 4, 4]]]]).astype(np.float32)
+
+    error = np.ones(shape=[1, 1, 3, 3]) * 1.0e-6
+
+    elu_grad = NetEluGrad()
+    output = elu_grad(dy, y)
+    print(output)
+    diff = np.abs(output.asnumpy() - expect)
+    # Relative error; abs() keeps the check meaningful for negative expected entries.
+    double_check = diff / np.abs(expect)
+    assert np.all(double_check < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_elu_grad_fp16():
+    y = Tensor(np.array([[0.5, 2, 5.5], [4.5, -2, 0]]).astype(np.float16))
+    dy = Tensor(np.array([[2, 1, 1.5], [-0.5, -1, -3]]).astype(np.float16))
+    expect = np.array([[2, 1, 1.5], [-0.5, 1, -3]]).astype(np.float16)
+    error = np.ones(shape=[2, 3]) * 1.0e-3
+
+    elu_grad = NetEluGrad()
+    output = elu_grad(dy, y)
+    print(output)
+    diff = np.abs(output.asnumpy() - expect)
+    # Relative error; abs() keeps the check meaningful for negative expected entries.
+    double_check = diff / np.abs(expect)
+    assert np.all(double_check < error)
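Note: the fp32 expectations follow directly from the gradient rule; for example, y = -0.3 and dy = -11 give dy * (y + 1) = -11 * 0.7 = -7.7. A short NumPy snippet reproducing the whole expected array with plain `np.where` (no MindSpore API involved):

```python
import numpy as np

y = np.array([[[[-0.3, 1, 2], [1, -0.6, 1], [2, 1, -2]]]], dtype=np.float32)
dy = np.array([[[[-11, 2, 4], [-1, 1, -1], [-4, 4, -4]]]], dtype=np.float32)
# Same branch as the kernel: dy * (y + 1) where y < 0, else dy.
expect = np.where(y < 0, dy * (y + 1), dy)
print(expect)  # matches the `expect` array in test_elu_grad_fp32
```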