@@ -93,6 +93,20 @@ void Gelu(const T *in, T *out, size_t start, size_t end) {
     out[i] = x * ((T)1.0 + tanh_res) / (T)2.0;
   }
 }
+
+template <typename T>
+void Asin(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = asin(in[i]);
+  }
+}
+
+template <typename T>
+void ACos(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = acos(in[i]);
+  }
+}
 }  // namespace
 
 void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -116,6 +130,10 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = RECIPROCAL;
   } else if (kernel_name == prim::kPrimGelu->name()) {
     operate_type_ = GELU;
+  } else if (kernel_name == prim::kPrimAsin->name()) {
+    operate_type_ = ASIN;
+  } else if (kernel_name == prim::kPrimACos->name()) {
+    operate_type_ = ACOS;
   }
   dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
   target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
@@ -216,6 +234,10 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
       threads.emplace_back(std::thread(Reciprocal<T>, input, output, start, end));
     } else if (operate_type_ == GELU) {
       threads.emplace_back(std::thread(Gelu<T>, input, output, start, end));
+    } else if (operate_type_ == ASIN) {
+      threads.emplace_back(std::thread(Asin<T>, input, output, start, end));
+    } else if (operate_type_ == ACOS) {
+      threads.emplace_back(std::thread(ACos<T>, input, output, start, end));
     }
     start += once_compute_size;
   }
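For orientation, the branches above run inside LaunchKernel's sharding loop: the flat element range is split into chunks of once_compute_size and each chunk is handed to a worker thread. A minimal Python analogue of that pattern, assuming a fixed thread count purely for illustration (the real kernel derives its thread budget from the device, and these names are this sketch's, not MindSpore's):

import numpy as np
from threading import Thread

def asin_slice(src, dst, start, end):
    # Each worker applies the element-wise op to its own [start, end) slice.
    dst[start:end] = np.arcsin(src[start:end])

data = np.linspace(-1, 1, 10000, dtype=np.float32)
out = np.empty_like(data)
thread_num = 4  # assumed for the sketch
once_compute_size = (data.size + thread_num - 1) // thread_num
threads = [Thread(target=asin_slice, args=(data, out, s, min(s + once_compute_size, data.size)))
           for s in range(0, data.size, once_compute_size)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert np.allclose(out, np.arcsin(data))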
@@ -70,6 +70,14 @@ MS_REG_CPU_KERNEL(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                   ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                  ArithmeticSelfCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -93,6 +93,10 @@ enum OperateType {
   RECIPROCAL,
   GELU,
   GELUGRAD,
+  ASIN,
+  ACOS,
+  ASINGRAD,
+  ACOSGRAD,
 };
 
 class CPUKernel : public kernel::KernelMod {
@@ -90,6 +90,48 @@ void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, si
   }
 }
+
+template <typename T>
+void EltWiseGradCPUKernel::AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    T dividend = input2[i];
+    T divisor = sqrt(1 - input1[i] * input1[i]);
+    if (divisor == 0) {
+      if (dividend == 0) {
+        out[i] = std::numeric_limits<T>::quiet_NaN();
+        continue;
+      }
+      if (std::numeric_limits<T>::has_infinity) {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
+      } else {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
+      }
+      continue;
+    }
+    out[i] = dividend / divisor;
+  }
+}
+
+template <typename T>
+void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    T dividend = -input2[i];
+    T divisor = sqrt(1 - input1[i] * input1[i]);
+    if (divisor == 0) {
+      if (dividend == 0) {
+        out[i] = std::numeric_limits<T>::quiet_NaN();
+        continue;
+      }
+      if (std::numeric_limits<T>::has_infinity) {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
+      } else {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
+      }
+      continue;
+    }
+    out[i] = dividend / divisor;
+  }
+}
 
 void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
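For reference, AsinGrad and ACosGrad are the closed-form chain-rule terms, with input1 holding the forward input x and input2 the incoming gradient dy:

  d/dx arcsin(x) =  1 / sqrt(1 - x^2)   =>  AsinGrad(x, dy) =  dy / sqrt(1 - x^2)
  d/dx arccos(x) = -1 / sqrt(1 - x^2)   =>  ACosGrad(x, dy) = -dy / sqrt(1 - x^2)

The divisor == 0 branches cover the boundary |x| = 1, where the derivative is unbounded: 0/0 produces a quiet NaN, and a nonzero dy maps to +/- infinity, falling back to the type's extreme finite values for types without an infinity (the Int32 registrations).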
@@ -107,6 +149,10 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = SQRTGRAD;
   } else if (kernel_name == "GeluGrad") {
     operate_type_ = GELUGRAD;
+  } else if (kernel_name == "AsinGrad") {
+    operate_type_ = ASINGRAD;
+  } else if (kernel_name == "ACosGrad") {
+    operate_type_ = ACOSGRAD;
   } else {
     MS_LOG(EXCEPTION) << "Not support " << kernel_name;
   }
@@ -188,6 +234,10 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
       threads.emplace_back(std::thread(&EltWiseGradCPUKernel::SqrtGrad<T>, this, input1, input2, output, start, end));
     } else if (operate_type_ == GELUGRAD) {
       threads.emplace_back(std::thread(&EltWiseGradCPUKernel::GeluGrad<T>, this, input1, input2, output, start, end));
+    } else if (operate_type_ == ASINGRAD) {
+      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad<T>, this, input1, input2, output, start, end));
+    } else if (operate_type_ == ACOSGRAD) {
+      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end));
     } else {
       MS_LOG(EXCEPTION) << "Not support " << operate_type_;
     }
@@ -17,6 +17,7 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ELTWISE_GRAD_CPU_KERNEL_H_
 
 #include <memory>
 #include <vector>
+#include <limits>
 #include "backend/kernel_compiler/cpu/cpu_kernel.h"
 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
@@ -49,6 +50,10 @@ class EltWiseGradCPUKernel : public CPUKernel {
   void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
   template <typename T>
   void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
   std::vector<size_t> input_shape0_;
   std::vector<size_t> input_shape1_;
   std::vector<size_t> input_element_num0_;
@@ -90,6 +95,20 @@ MS_REG_CPU_KERNEL(GeluGrad,
                     .AddInputAttr(kNumberTypeFloat32)
                     .AddOutputAttr(kNumberTypeFloat32),
                   EltWiseGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  AsinGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  EltWiseGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  AsinGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  EltWiseGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  ACosGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  EltWiseGradCPUKernel);
+MS_REG_CPU_KERNEL(
+  ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  EltWiseGradCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -278,6 +278,10 @@ inline const PrimitivePtr kPrimSplitV = std::make_shared<Primitive>("SplitV");
 inline const PrimitivePtr kPrimLinSpace = std::make_shared<Primitive>("LinSpace");
 inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign");
 inline const PrimitivePtr kPrimSquaredDifference = std::make_shared<Primitive>("SquaredDifference");
+inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
+inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos");
+inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad");
+inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad");
 
 // Statements
 inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
@@ -351,7 +355,7 @@ inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_
 inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
 inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
 
-// Other primitve not used by backend but used in core;
+// Other primitive not used by backend but used in core;
 inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
 inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops.operations import _grad_ops as G
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetACosGrad(nn.Cell):
+    def __init__(self):
+        super(NetACosGrad, self).__init__()
+        self.acosGrad = G.ACosGrad()
+
+    def construct(self, x, dy):
+        return self.acosGrad(x, dy)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_acos_grad():
+    x = np.array([-0.5, 0, 0.5]).astype('float32')
+    dy = np.array([1, 0, -1]).astype('float32')
+    acos_grad = NetACosGrad()
+    output = acos_grad(Tensor(x), Tensor(dy))
+    print(output)
+    expect = -dy / np.sqrt(1 - x * x)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetACos(nn.Cell):
+    def __init__(self):
+        super(NetACos, self).__init__()
+        self.acos = P.ACos()
+
+    def construct(self, x):
+        return self.acos(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_acos():
+    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetACos()
+    output = net(input_x)
+    print(output)
+    expect = np.arccos(np_array)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops.operations import _grad_ops as G
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAsinGrad(nn.Cell):
+    def __init__(self):
+        super(NetAsinGrad, self).__init__()
+        self.asinGrad = G.AsinGrad()
+
+    def construct(self, x, dy):
+        return self.asinGrad(x, dy)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_asin_grad():
+    x = np.array([-0.5, 0, 0.5]).astype('float32')
+    dy = np.array([1, 0, -1]).astype('float32')
+    asin_grad = NetAsinGrad()
+    output = asin_grad(Tensor(x), Tensor(dy))
+    print(output)
+    expect = dy / np.sqrt(1 - x * x)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAsin(nn.Cell):
+    def __init__(self):
+        super(NetAsin, self).__init__()
+        self.asin = P.Asin()
+
+    def construct(self, x):
+        return self.asin(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_asin():
+    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetAsin()
+    output = net(input_x)
+    print(output)
+    expect = np.arcsin(np_array)
+    assert np.allclose(output.asnumpy(), expect)
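As an extra sanity check outside MindSpore, the analytic expectations used in the tests above can be cross-checked against central finite differences in plain NumPy. This sketch is illustrative only and not part of the change (eps and atol are assumptions):

import numpy as np

def numerical_grad(f, x, eps=1e-5):
    # Central-difference approximation of f'(x).
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = np.array([-0.5, 0, 0.5], dtype=np.float64)
dy = np.array([1, 0, -1], dtype=np.float64)
# AsinGrad(x, dy) = dy / sqrt(1 - x^2); ACosGrad(x, dy) = -dy / sqrt(1 - x^2)
assert np.allclose(dy * numerical_grad(np.arcsin, x), dy / np.sqrt(1 - x * x), atol=1e-6)
assert np.allclose(dy * numerical_grad(np.arccos, x), -dy / np.sqrt(1 - x * x), atol=1e-6)
print('finite-difference check passed')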