From: @wangrao124 Reviewed-by: @wuxuejian,@kisnwang Signed-off-by: @wuxuejiantags/v1.2.0-rc1
| @@ -16,6 +16,7 @@ | |||||
| #include <cmath> | #include <cmath> | ||||
| #include <string> | #include <string> | ||||
| #include <thread> | #include <thread> | ||||
| #include <map> | |||||
| #include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h" | #include "backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h" | ||||
| #include "runtime/device/cpu/cpu_device_address.h" | #include "runtime/device/cpu/cpu_device_address.h" | ||||
| @@ -107,6 +108,34 @@ void ACos(const T *in, T *out, size_t start, size_t end) { | |||||
| out[i] = acos(in[i]); | out[i] = acos(in[i]); | ||||
| } | } | ||||
| } | } | ||||
| template <typename T> | |||||
| void Atan(const T *in, T *out, size_t start, size_t end) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| out[i] = atan(in[i]); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void Sin(const T *in, T *out, size_t start, size_t end) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| out[i] = sin(in[i]); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void Cos(const T *in, T *out, size_t start, size_t end) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| out[i] = cos(in[i]); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| void Tan(const T *in, T *out, size_t start, size_t end) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| out[i] = tan(in[i]); | |||||
| } | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| @@ -134,6 +163,14 @@ void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| operate_type_ = ASIN; | operate_type_ = ASIN; | ||||
| } else if (kernel_name == prim::kPrimACos->name()) { | } else if (kernel_name == prim::kPrimACos->name()) { | ||||
| operate_type_ = ACOS; | operate_type_ = ACOS; | ||||
| } else if (kernel_name == prim::kPrimAtan->name()) { | |||||
| operate_type_ = ATAN; | |||||
| } else if (kernel_name == prim::kPrimSin->name()) { | |||||
| operate_type_ = SIN; | |||||
| } else if (kernel_name == prim::kPrimCos->name()) { | |||||
| operate_type_ = COS; | |||||
| } else if (kernel_name == prim::kPrimTan->name()) { | |||||
| operate_type_ = TAN; | |||||
| } | } | ||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0); | target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0); | ||||
| @@ -214,31 +251,18 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | ||||
| return; | return; | ||||
| } | } | ||||
| static const std::map<OperateType, std::function<void(const T *in, T *out, size_t start, size_t end)>> | |||||
| kArithmeticOpFuncMap = {{SQUARE, Square<T>}, {SIGN, Sign<T>}, | |||||
| {NEG, Neg<T>}, {LOGICALNOT, LogicalNot<T>}, | |||||
| {ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>}, | |||||
| {FLOOR, Floor<T>}, {RECIPROCAL, Reciprocal<T>}, | |||||
| {GELU, Gelu<T>}, {SIN, Sin<T>}, | |||||
| {COS, Cos<T>}, {TAN, Tan<T>}, | |||||
| {ASIN, Asin<T>}, {ACOS, ACos<T>}, | |||||
| {ATAN, Atan<T>}}; | |||||
| while (start < lens) { | while (start < lens) { | ||||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | ||||
| if (operate_type_ == SQUARE) { | |||||
| threads.emplace_back(std::thread(Square<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == NEG) { | |||||
| threads.emplace_back(std::thread(Neg<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == LOGICALNOT) { | |||||
| threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == ONESLIKE) { | |||||
| threads.emplace_back(std::thread(OnesLike<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == ZEROSLIKE) { | |||||
| threads.emplace_back(std::thread(ZerosLike<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == SIGN) { | |||||
| threads.emplace_back(std::thread(Sign<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == FLOOR) { | |||||
| threads.emplace_back(std::thread(Floor<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == RECIPROCAL) { | |||||
| threads.emplace_back(std::thread(Reciprocal<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == GELU) { | |||||
| threads.emplace_back(std::thread(Gelu<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == ASIN) { | |||||
| threads.emplace_back(std::thread(Asin<T>, input, output, start, end)); | |||||
| } else if (operate_type_ == ACOS) { | |||||
| threads.emplace_back(std::thread(ACos<T>, input, output, start, end)); | |||||
| } | |||||
| threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end)); | |||||
| start += once_compute_size; | start += once_compute_size; | ||||
| } | } | ||||
| for (size_t i = 0; i < threads.size(); ++i) { | for (size_t i = 0; i < threads.size(); ++i) { | ||||
| @@ -78,6 +78,22 @@ MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA | |||||
| ArithmeticSelfCPUKernel); | ArithmeticSelfCPUKernel); | ||||
| MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| ArithmeticSelfCPUKernel); | ArithmeticSelfCPUKernel); | ||||
| MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ArithmeticSelfCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ArithmeticSelfCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ArithmeticSelfCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ArithmeticSelfCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ArithmeticSelfCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ArithmeticSelfCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ArithmeticSelfCPUKernel); | |||||
| MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ArithmeticSelfCPUKernel); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -95,8 +95,13 @@ enum OperateType { | |||||
| GELUGRAD, | GELUGRAD, | ||||
| ASIN, | ASIN, | ||||
| ACOS, | ACOS, | ||||
| ATAN, | |||||
| ASINGRAD, | ASINGRAD, | ||||
| ACOSGRAD, | ACOSGRAD, | ||||
| ATANGRAD, | |||||
| SIN, | |||||
| COS, | |||||
| TAN, | |||||
| }; | }; | ||||
| class CPUKernel : public kernel::KernelMod { | class CPUKernel : public kernel::KernelMod { | ||||
| @@ -132,6 +132,27 @@ void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, si | |||||
| } | } | ||||
| } | } | ||||
| template <typename T> | |||||
| void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| T dividend = input2[i]; | |||||
| T divisor = 1 + input1[i] * input1[i]; | |||||
| if (divisor == 0) { | |||||
| if (dividend == 0) { | |||||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||||
| continue; | |||||
| } | |||||
| if (std::numeric_limits<T>::has_infinity) { | |||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||||
| } else { | |||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| out[i] = dividend / divisor; | |||||
| } | |||||
| } | |||||
| void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); | ||||
| @@ -153,6 +174,8 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| operate_type_ = ASINGRAD; | operate_type_ = ASINGRAD; | ||||
| } else if (kernel_name == "ACosGrad") { | } else if (kernel_name == "ACosGrad") { | ||||
| operate_type_ = ACOSGRAD; | operate_type_ = ACOSGRAD; | ||||
| } else if (kernel_name == "AtanGrad") { | |||||
| operate_type_ = ATANGRAD; | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Not support " << kernel_name; | MS_LOG(EXCEPTION) << "Not support " << kernel_name; | ||||
| } | } | ||||
| @@ -238,6 +261,8 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c | |||||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad<T>, this, input1, input2, output, start, end)); | threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad<T>, this, input1, input2, output, start, end)); | ||||
| } else if (operate_type_ == ACOSGRAD) { | } else if (operate_type_ == ACOSGRAD) { | ||||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end)); | threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end)); | ||||
| } else if (operate_type_ == ATANGRAD) { | |||||
| threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AtanGrad<T>, this, input1, input2, output, start, end)); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | MS_LOG(EXCEPTION) << "Not support " << operate_type_; | ||||
| } | } | ||||
| @@ -54,6 +54,8 @@ class EltWiseGradCPUKernel : public CPUKernel { | |||||
| void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | ||||
| template <typename T> | template <typename T> | ||||
| void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | ||||
| template <typename T> | |||||
| void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||||
| std::vector<size_t> input_shape0_; | std::vector<size_t> input_shape0_; | ||||
| std::vector<size_t> input_shape1_; | std::vector<size_t> input_shape1_; | ||||
| std::vector<size_t> input_element_num0_; | std::vector<size_t> input_element_num0_; | ||||
| @@ -109,6 +111,13 @@ MS_REG_CPU_KERNEL( | |||||
| MS_REG_CPU_KERNEL( | MS_REG_CPU_KERNEL( | ||||
| ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| EltWiseGradCPUKernel); | EltWiseGradCPUKernel); | ||||
| MS_REG_CPU_KERNEL( | |||||
| AtanGrad, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| EltWiseGradCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | |||||
| AtanGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| EltWiseGradCPUKernel); | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -388,6 +388,7 @@ inline const PrimitivePtr kPrimSign = std::make_shared<Primitive>("Sign"); | |||||
| inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos"); | inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>("ACos"); | ||||
| inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad"); | inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad"); | ||||
| inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad"); | inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>("ACosGrad"); | ||||
| inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad"); | |||||
| inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod"); | inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod"); | ||||
| inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where"); | inline const PrimitivePtr kPrimWhere = std::make_shared<Primitive>("Where"); | ||||
| @@ -3342,7 +3342,7 @@ class Cos(PrimitiveWithInfer): | |||||
| Tensor, has the same shape as `input_x`. | Tensor, has the same shape as `input_x`. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | Examples: | ||||
| >>> cos = ops.Cos() | >>> cos = ops.Cos() | ||||
| @@ -3379,7 +3379,7 @@ class ACos(PrimitiveWithInfer): | |||||
| Tensor, has the same shape as `input_x`. | Tensor, has the same shape as `input_x`. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | Examples: | ||||
| >>> acos = ops.ACos() | >>> acos = ops.ACos() | ||||
| @@ -3412,7 +3412,7 @@ class Sin(PrimitiveWithInfer): | |||||
| Tensor, has the same shape as `input_x`. | Tensor, has the same shape as `input_x`. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | Examples: | ||||
| >>> sin = ops.Sin() | >>> sin = ops.Sin() | ||||
| @@ -3449,7 +3449,7 @@ class Asin(PrimitiveWithInfer): | |||||
| Tensor, has the same shape as `input_x`. | Tensor, has the same shape as `input_x`. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | Examples: | ||||
| >>> asin = ops.Asin() | >>> asin = ops.Asin() | ||||
| @@ -3666,7 +3666,7 @@ class Tan(PrimitiveWithInfer): | |||||
| Tensor, has the same shape as `input_x`. | Tensor, has the same shape as `input_x`. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` | |||||
| ``Ascend`` ``CPU`` | |||||
| Examples: | Examples: | ||||
| >>> tan = ops.Tan() | >>> tan = ops.Tan() | ||||
| @@ -3704,7 +3704,7 @@ class Atan(PrimitiveWithInfer): | |||||
| A Tensor, has the same type as the input. | A Tensor, has the same type as the input. | ||||
| Supported Platforms: | Supported Platforms: | ||||
| ``Ascend`` ``GPU`` | |||||
| ``Ascend`` ``GPU`` ``CPU`` | |||||
| Examples: | Examples: | ||||
| >>> input_x = Tensor(np.array([1.0, 0.0]), mindspore.float32) | >>> input_x = Tensor(np.array([1.0, 0.0]), mindspore.float32) | ||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.ops.operations import _grad_ops as G | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| class NetAtanGrad(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetAtanGrad, self).__init__() | |||||
| self.atanGrad = G.AtanGrad() | |||||
| def construct(self, x, dy): | |||||
| return self.atanGrad(x, dy) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_atan_grad(): | |||||
| x = np.array([-0.5, 0, 0.5]).astype('float32') | |||||
| dy = np.array([1, 0, -1]).astype('float32') | |||||
| atan_grad = NetAtanGrad() | |||||
| output = atan_grad(Tensor(x), Tensor(dy)) | |||||
| print(output) | |||||
| expect = dy / (1 + x * x) | |||||
| assert np.allclose(output.asnumpy(), expect) | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| class NetAtan(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetAtan, self).__init__() | |||||
| self.atan = P.Atan() | |||||
| def construct(self, x): | |||||
| return self.atan(x) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_atan(): | |||||
| np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32') | |||||
| input_x = Tensor(np_array) | |||||
| net = NetAtan() | |||||
| output = net(input_x) | |||||
| print(output) | |||||
| expect = np.arctan(np_array) | |||||
| assert np.allclose(output.asnumpy(), expect) | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| class NetCos(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetCos, self).__init__() | |||||
| self.cos = P.Cos() | |||||
| def construct(self, x): | |||||
| return self.cos(x) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_cos(): | |||||
| np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32') | |||||
| input_x = Tensor(np_array) | |||||
| net = NetCos() | |||||
| output = net(input_x) | |||||
| print(output) | |||||
| expect = np.cos(np_array) | |||||
| assert np.allclose(output.asnumpy(), expect) | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| class NetSin(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetSin, self).__init__() | |||||
| self.sin = P.Sin() | |||||
| def construct(self, x): | |||||
| return self.sin(x) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_sin(): | |||||
| np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32') | |||||
| input_x = Tensor(np_array) | |||||
| net = NetSin() | |||||
| output = net(input_x) | |||||
| print(output) | |||||
| expect = np.sin(np_array) | |||||
| assert np.allclose(output.asnumpy(), expect) | |||||
| @@ -0,0 +1,46 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.ops import operations as P | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| class NetTan(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetTan, self).__init__() | |||||
| self.tan = P.Tan() | |||||
| def construct(self, x): | |||||
| return self.tan(x) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_tan(): | |||||
| np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32') | |||||
| input_x = Tensor(np_array) | |||||
| net = NetTan() | |||||
| output = net(input_x) | |||||
| print(output) | |||||
| expect = np.tan(np_array) | |||||
| assert np.allclose(output.asnumpy(), expect) | |||||