From: @wangrao124 Reviewed-by: @wuxuejian,@kisnwang Signed-off-by: @wuxuejian tags/v1.2.0-rc1
| @@ -16,6 +16,7 @@ | |||
| #include <cmath> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <map> | |||
| #include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| @@ -107,6 +108,34 @@ void ACos(const T *in, T *out, size_t start, size_t end) { | |||
| out[i] = acos(in[i]); | |||
| } | |||
| } | |||
// Element-wise arctangent over the half-open index range [start, end).
// Reads in[i] and writes out[i]; elements outside the range are left untouched.
template <typename T>
void Atan(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    // std::atan picks the overload matching T (float stays float instead of being promoted to
    // double by the C function). For integral T the result is a double, truncated by the cast.
    out[i] = static_cast<T>(std::atan(in[i]));
  }
}
// Element-wise sine over the half-open index range [start, end).
// Reads in[i] and writes out[i]; elements outside the range are left untouched.
template <typename T>
void Sin(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    // std::sin picks the overload matching T (float stays float instead of being promoted to
    // double by the C function). For integral T the result is a double, truncated by the cast.
    out[i] = static_cast<T>(std::sin(in[i]));
  }
}
// Element-wise cosine over the half-open index range [start, end).
// Reads in[i] and writes out[i]; elements outside the range are left untouched.
template <typename T>
void Cos(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    // std::cos picks the overload matching T (float stays float instead of being promoted to
    // double by the C function). For integral T the result is a double, truncated by the cast.
    out[i] = static_cast<T>(std::cos(in[i]));
  }
}
// Element-wise tangent over the half-open index range [start, end).
// Reads in[i] and writes out[i]; elements outside the range are left untouched.
template <typename T>
void Tan(const T *in, T *out, size_t start, size_t end) {
  for (size_t i = start; i < end; i++) {
    // std::tan picks the overload matching T (float stays float instead of being promoted to
    // double by the C function). For integral T the result is a double, truncated by the cast.
    out[i] = static_cast<T>(std::tan(in[i]));
  }
}
| } // namespace | |||
| void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| @@ -134,6 +163,14 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| operate_type_ = ASIN; | |||
| } else if (kernel_name == prim::kPrimACos->name()) { | |||
| operate_type_ = ACOS; | |||
| } else if (kernel_name == prim::kPrimAtan->name()) { | |||
| operate_type_ = ATAN; | |||
| } else if (kernel_name == prim::kPrimSin->name()) { | |||
| operate_type_ = SIN; | |||
| } else if (kernel_name == prim::kPrimCos->name()) { | |||
| operate_type_ = COS; | |||
| } else if (kernel_name == prim::kPrimTan->name()) { | |||
| operate_type_ = TAN; | |||
| } | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0); | |||
| @@ -214,31 +251,18 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | |||
| return; | |||
| } | |||
| static const std::map<OperateType, std::function<void(const T *in, T *out, size_t start, size_t end)>> | |||
| kArithmeticOpFuncMap = {{SQUARE, Square<T>}, {SIGN, Sign<T>}, | |||
| {NEG, Neg<T>}, {LOGICALNOT, LogicalNot<T>}, | |||
| {ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>}, | |||
| {FLOOR, Floor<T>}, {RECIPROCAL, Reciprocal<T>}, | |||
| {GELU, Gelu<T>}, {SIN, Sin<T>}, | |||
| {COS, Cos<T>}, {TAN, Tan<T>}, | |||
| {ASIN, Asin<T>}, {ACOS, ACos<T>}, | |||
| {ATAN, Atan<T>}}; | |||
| while (start < lens) { | |||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||
| if (operate_type_ == SQUARE) { | |||
| threads.emplace_back(std::thread(Square<T>, input, output, start, end)); | |||
| } else if (operate_type_ == NEG) { | |||
| threads.emplace_back(std::thread(Neg<T>, input, output, start, end)); | |||
| } else if (operate_type_ == LOGICALNOT) { | |||
| threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end)); | |||
| } else if (operate_type_ == ONESLIKE) { | |||
| threads.emplace_back(std::thread(OnesLike<T>, input, output, start, end)); | |||
| } else if (operate_type_ == ZEROSLIKE) { | |||
| threads.emplace_back(std::thread(ZerosLike<T>, input, output, start, end)); | |||
| } else if (operate_type_ == SIGN) { | |||
| threads.emplace_back(std::thread(Sign<T>, input, output, start, end)); | |||
| } else if (operate_type_ == FLOOR) { | |||
| threads.emplace_back(std::thread(Floor<T>, input, output, start, end)); | |||
| } else if (operate_type_ == RECIPROCAL) { | |||
| threads.emplace_back(std::thread(Reciprocal<T>, input, output, start, end)); | |||
| } else if (operate_type_ == GELU) { | |||
| threads.emplace_back(std::thread(Gelu<T>, input, output, start, end)); | |||
| } else if (operate_type_ == ASIN) { | |||
| threads.emplace_back(std::thread(Asin<T>, input, output, start, end)); | |||
| } else if (operate_type_ == ACOS) { | |||
| threads.emplace_back(std::thread(ACos<T>, input, output, start, end)); | |||
| } | |||
| threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end)); | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| @@ -78,6 +78,22 @@ MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ArithmeticSelfCPUKernel); | |||
| MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| ArithmeticSelfCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -95,8 +95,13 @@ enum OperateType { | |||
| GELUGRAD, | |||
| ASIN, | |||
| ACOS, | |||
| ATAN, | |||
| ASINGRAD, | |||
| ACOSGRAD, | |||
| ATANGRAD, | |||
| SIN, | |||
| COS, | |||
| TAN, | |||
| }; | |||
| class CPUKernel : public kernel::KernelMod { | |||
| @@ -132,6 +132,27 @@ void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, si | |||
| } | |||
| } | |||
| template <typename T> | |||
| void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| T dividend = input2[i]; | |||
| T divisor = 1 + input1[i] * input1[i]; | |||
| if (divisor == 0) { | |||
| if (dividend == 0) { | |||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||
| continue; | |||
| } | |||
| if (std::numeric_limits<T>::has_infinity) { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||
| } else { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||
| } | |||
| continue; | |||
| } | |||
| out[i] = dividend / divisor; | |||
| } | |||
| } | |||
| void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| @@ -153,6 +174,8 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| operate_type_ = ASINGRAD; | |||
| } else if (kernel_name == "ACosGrad") { | |||
| operate_type_ = ACOSGRAD; | |||
| } else if (kernel_name == "AtanGrad") { | |||
| operate_type_ = ATANGRAD; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << kernel_name; | |||
| } | |||
| @@ -238,6 +261,8 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c | |||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == ACOSGRAD) { | |||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == ATANGRAD) { | |||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AtanGrad<T>, this, input1, input2, output, start, end)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| @@ -54,6 +54,8 @@ class EltWiseGradCPUKernel : public CPUKernel { | |||
| void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| template <typename T> | |||
| void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| std::vector<size_t> input_shape0_; | |||
| std::vector<size_t> input_shape1_; | |||
| std::vector<size_t> input_element_num0_; | |||
| @@ -109,6 +111,13 @@ MS_REG_CPU_KERNEL( | |||
| MS_REG_CPU_KERNEL( | |||
| ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| EltWiseGradCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| AtanGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| EltWiseGradCPUKernel); | |||
| MS_REG_CPU_KERNEL( | |||
| AtanGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| EltWiseGradCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -388,6 +388,7 @@ inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign"); | |||
| inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos"); | |||
| inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad"); | |||
| inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad"); | |||
| inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad"); | |||
| inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod"); | |||
| inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where"); | |||
| @@ -3342,7 +3342,7 @@ class Cos(PrimitiveWithInfer): | |||
| Tensor, has the same shape as `input_x`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> cos = ops.Cos() | |||
| @@ -3379,7 +3379,7 @@ class ACos(PrimitiveWithInfer): | |||
| Tensor, has the same shape as `input_x`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> acos = ops.ACos() | |||
| @@ -3412,7 +3412,7 @@ class Sin(PrimitiveWithInfer): | |||
| Tensor, has the same shape as `input_x`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> sin = ops.Sin() | |||
| @@ -3449,7 +3449,7 @@ class Asin(PrimitiveWithInfer): | |||
| Tensor, has the same shape as `input_x`. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> asin = ops.Asin() | |||
| @@ -3666,7 +3666,7 @@ class Tan(PrimitiveWithInfer): | |||
| Tensor, has the same shape as `input_x`. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| ``Ascend`` ``CPU`` | |||
| Examples: | |||
| >>> tan = ops.Tan() | |||
| @@ -3704,7 +3704,7 @@ class Atan(PrimitiveWithInfer): | |||
| A Tensor, has the same type as the input. | |||
| Supported Platforms: | |||
| ``Ascend`` ``GPU`` | |||
| ``Ascend`` ``GPU`` ``CPU`` | |||
| Examples: | |||
| >>> input_x = Tensor(np.array([1.0, 0.0]), mindspore.float32) | |||
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
class NetAtanGrad(nn.Cell):
    """Minimal cell that wraps the ``G.AtanGrad`` primitive for the CPU test."""

    def __init__(self):
        super().__init__()
        self.atanGrad = G.AtanGrad()

    def construct(self, x, dy):
        """Apply AtanGrad to ``x`` and ``dy`` and return its output."""
        grad = self.atanGrad(x, dy)
        return grad
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atan_grad():
    """Check AtanGrad against the NumPy reference ``dy / (1 + x * x)``."""
    x_np = np.array([-0.5, 0, 0.5]).astype('float32')
    dy_np = np.array([1, 0, -1]).astype('float32')

    net = NetAtanGrad()
    result = net(Tensor(x_np), Tensor(dy_np))
    print(result)

    reference = dy_np / (1 + x_np * x_np)
    assert np.allclose(result.asnumpy(), reference)
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
class NetAtan(nn.Cell):
    """Single-op cell that forwards its input through ``P.Atan``."""

    def __init__(self):
        super().__init__()
        self.atan = P.Atan()

    def construct(self, x):
        """Return ``Atan`` applied to ``x``."""
        result = self.atan(x)
        return result
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atan():
    """Run Atan on a small float32 vector and compare with ``np.arctan``."""
    samples = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    tensor_in = Tensor(samples)

    result = NetAtan()(tensor_in)
    print(result)

    reference = np.arctan(samples)
    assert np.allclose(result.asnumpy(), reference)
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
class NetCos(nn.Cell):
    """Single-op cell that forwards its input through ``P.Cos``."""

    def __init__(self):
        super().__init__()
        self.cos = P.Cos()

    def construct(self, x):
        """Return ``Cos`` applied to ``x``."""
        result = self.cos(x)
        return result
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cos():
    """Run Cos on a small float32 vector and compare with ``np.cos``."""
    samples = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    tensor_in = Tensor(samples)

    result = NetCos()(tensor_in)
    print(result)

    reference = np.cos(samples)
    assert np.allclose(result.asnumpy(), reference)
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
class NetSin(nn.Cell):
    """Single-op cell that forwards its input through ``P.Sin``."""

    def __init__(self):
        super().__init__()
        self.sin = P.Sin()

    def construct(self, x):
        """Return ``Sin`` applied to ``x``."""
        result = self.sin(x)
        return result
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sin():
    """Run Sin on a small float32 vector and compare with ``np.sin``."""
    samples = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    tensor_in = Tensor(samples)

    result = NetSin()(tensor_in)
    print(result)

    reference = np.sin(samples)
    assert np.allclose(result.asnumpy(), reference)
| @@ -0,0 +1,46 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.ops import operations as P | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
class NetTan(nn.Cell):
    """Single-op cell that forwards its input through ``P.Tan``."""

    def __init__(self):
        super().__init__()
        self.tan = P.Tan()

    def construct(self, x):
        """Return ``Tan`` applied to ``x``."""
        result = self.tan(x)
        return result
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tan():
    """Run Tan on a small float32 vector and compare with ``np.tan``."""
    samples = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
    tensor_in = Tensor(samples)

    result = NetTan()(tensor_in)
    print(result)

    reference = np.tan(samples)
    assert np.allclose(result.asnumpy(), reference)