@@ -16,6 +16,7 @@
 #include <cmath>
 #include <string>
 #include <thread>
+#include <map>
 #include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"
 #include "runtime/device/cpu/cpu_device_address.h"
@@ -235,45 +236,40 @@ void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out,
   }
 }
+template <typename T>
+void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end) {
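+  // GenIndex maps the flat output index i to the matching element indices of the two inputs (broadcast-aware).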
+  for (size_t i = start; i < end; i++) {
+    std::vector<size_t> idx;
+    GenIndex(i, &idx);
+    out[i] = atan2(input1[idx[0]], input2[idx[1]]);
+  }
+}
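+// Name -> OperateType table; InitKernel now resolves binary ops with a map lookup instead of a long if-else chain.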
+static const std::map<std::string, OperateType> kArithmeticBinOpTypeMap = {
+  {prim::kPrimGreater->name(), GREATER},
+  {prim::kPrimAdd->name(), ADD},
+  {prim::kPrimGreaterEqual->name(), GREATEREQUAL},
+  {prim::kPrimSub->name(), SUB},
+  {prim::kPrimLogicalAnd->name(), LOGICALAND},
+  {prim::kPrimMul->name(), MUL},
+  {prim::kPrimLessEqual->name(), LESSEQUAL},
+  {prim::kPrimDiv->name(), DIV},
+  {prim::kPrimLogicalOr->name(), LOGICALOR},
+  {prim::kPrimMod->name(), MOD},
+  {prim::kPrimAssignAdd->name(), ASSIGNADD},
+  {prim::kPrimPow->name(), POW},
+  {prim::kPrimFloorDiv->name(), FLOORDIV},
+  {prim::kPrimLess->name(), LESS},
+  {prim::kPrimNotEqual->name(), NOTEQUAL},
+  {prim::kPrimAtan2->name(), ATAN2},
+  {prim::kPrimRealDiv->name(), REALDIV},
+  {prim::kPrimEqual->name(), EQUAL},
+  {prim::kPrimSquaredDifference->name(), SQUAREDDIFFERENCE}};
 void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-  if (kernel_name == prim::kPrimAdd->name()) {
-    operate_type_ = ADD;
-  } else if (kernel_name == prim::kPrimSub->name()) {
-    operate_type_ = SUB;
-  } else if (kernel_name == prim::kPrimMul->name()) {
-    operate_type_ = MUL;
-  } else if (kernel_name == prim::kPrimRealDiv->name()) {
-    operate_type_ = REALDIV;
-  } else if (kernel_name == prim::kPrimDiv->name()) {
-    operate_type_ = DIV;
-  } else if (kernel_name == prim::kPrimFloorDiv->name()) {
-    operate_type_ = FLOORDIV;
-  } else if (kernel_name == prim::kPrimMod->name()) {
-    operate_type_ = MOD;
-  } else if (kernel_name == prim::kPrimPow->name()) {
-    operate_type_ = POW;
-  } else if (kernel_name == prim::kPrimLess->name()) {
-    operate_type_ = LESS;
-  } else if (kernel_name == prim::kPrimEqual->name()) {
-    operate_type_ = EQUAL;
-  } else if (kernel_name == prim::kPrimNotEqual->name()) {
-    operate_type_ = NOTEQUAL;
-  } else if (kernel_name == prim::kPrimGreater->name()) {
-    operate_type_ = GREATER;
-  } else if (kernel_name == prim::kPrimGreaterEqual->name()) {
-    operate_type_ = GREATEREQUAL;
-  } else if (kernel_name == prim::kPrimLessEqual->name()) {
-    operate_type_ = LESSEQUAL;
-  } else if (kernel_name == prim::kPrimLogicalAnd->name()) {
-    operate_type_ = LOGICALAND;
-  } else if (kernel_name == prim::kPrimLogicalOr->name()) {
-    operate_type_ = LOGICALOR;
-  } else if (kernel_name == prim::kPrimAssignAdd->name()) {
-    operate_type_ = ASSIGNADD;
-  } else if (kernel_name == prim::kPrimSquaredDifference->name()) {
-    operate_type_ = SQUAREDDIFFERENCE;
+  if (kArithmeticBinOpTypeMap.find(kernel_name) != kArithmeticBinOpTypeMap.end()) {
+    operate_type_ = kArithmeticBinOpTypeMap.at(kernel_name);
   } else {
     MS_LOG(EXCEPTION) << "Not support " << kernel_name;
   }
@@ -448,6 +444,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
       threads.emplace_back(std::thread(&ArithmeticCPUKernel::Pow<T>, this, input1, input2, output, start, end));
     } else if (operate_type_ == ASSIGNADD) {
       threads.emplace_back(std::thread(&ArithmeticCPUKernel::AssignAdd<T>, this, input1, input2, output, start, end));
+    } else if (operate_type_ == ATAN2) {
+      threads.emplace_back(std::thread(&ArithmeticCPUKernel::Atan2<T>, this, input1, input2, output, start, end));
    } else if (operate_type_ == SQUAREDDIFFERENCE) {
       threads.emplace_back(
         std::thread(&ArithmeticCPUKernel::SquaredDifference<T>, this, input1, input2, output, start, end));
@@ -58,6 +58,8 @@ class ArithmeticCPUKernel : public CPUKernel {
   template <typename T>
   void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end);
   template <typename T>
   void Less(const T *input1, const T *input2, bool *out, size_t start, size_t end);
   template <typename T>
   void Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end);
@@ -279,6 +281,10 @@ MS_REG_CPU_KERNEL(
 MS_REG_CPU_KERNEL(
   LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
   ArithmeticCPUKernel);
+MS_REG_CPU_KERNEL(
+  Atan2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ArithmeticCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -136,42 +136,68 @@ void Tan(const T *in, T *out, size_t start, size_t end) {
     out[i] = tan(in[i]);
   }
 }
+template <typename T>
+void Sinh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = sinh(in[i]);
+  }
+}
+template <typename T>
+void Cosh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = cosh(in[i]);
+  }
+}
+template <typename T>
+void Asinh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = asinh(in[i]);
+  }
+}
+template <typename T>
+void Acosh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = acosh(in[i]);
+  }
+}
+template <typename T>
+void Atanh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = atanh(in[i]);
+  }
+}
 }  // namespace
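+// Unary name -> OperateType table, mirroring the binary-op map in arithmetic_cpu_kernel.cc.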
+static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG},
+                                                                        {prim::kPrimSquare->name(), SQUARE},
+                                                                        {prim::kPrimOnesLike->name(), ONESLIKE},
+                                                                        {prim::kPrimZerosLike->name(), ZEROSLIKE},
+                                                                        {prim::kPrimLogicalNot->name(), LOGICALNOT},
+                                                                        {prim::kPrimSign->name(), SIGN},
+                                                                        {prim::kPrimFloor->name(), FLOOR},
+                                                                        {prim::kPrimReciprocal->name(), RECIPROCAL},
+                                                                        {prim::kPrimGeLU->name(), GELU},
+                                                                        {prim::kPrimAsin->name(), ASIN},
+                                                                        {prim::kPrimACos->name(), ACOS},
+                                                                        {prim::kPrimAtan->name(), ATAN},
+                                                                        {prim::kPrimSin->name(), SIN},
+                                                                        {prim::kPrimCos->name(), COS},
+                                                                        {prim::kPrimTan->name(), TAN},
+                                                                        {prim::kPrimSinh->name(), SINH},
+                                                                        {prim::kPrimCosh->name(), COSH},
+                                                                        {prim::kPrimAsinh->name(), ASINH},
+                                                                        {prim::kPrimAcosh->name(), ACOSH},
+                                                                        {prim::kPrimAtanh->name(), ATANH}};
 void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-  if (kernel_name == prim::kPrimSquare->name()) {
-    operate_type_ = SQUARE;
-  } else if (kernel_name == prim::kPrimOnesLike->name()) {
-    operate_type_ = ONESLIKE;
-  } else if (kernel_name == prim::kPrimZerosLike->name()) {
-    operate_type_ = ZEROSLIKE;
-  } else if (kernel_name == prim::kPrimNeg->name()) {
-    operate_type_ = NEG;
-  } else if (kernel_name == prim::kPrimLogicalNot->name()) {
-    operate_type_ = LOGICALNOT;
-  } else if (kernel_name == prim::kPrimSign->name()) {
-    operate_type_ = SIGN;
-  } else if (kernel_name == prim::kPrimFloor->name()) {
-    operate_type_ = FLOOR;
-  } else if (kernel_name == prim::kPrimReciprocal->name()) {
-    operate_type_ = RECIPROCAL;
-  } else if (kernel_name == prim::kPrimGeLU->name()) {
-    operate_type_ = GELU;
-  } else if (kernel_name == prim::kPrimAsin->name()) {
-    operate_type_ = ASIN;
-  } else if (kernel_name == prim::kPrimACos->name()) {
-    operate_type_ = ACOS;
-  } else if (kernel_name == prim::kPrimAtan->name()) {
-    operate_type_ = ATAN;
-  } else if (kernel_name == prim::kPrimSin->name()) {
-    operate_type_ = SIN;
-  } else if (kernel_name == prim::kPrimCos->name()) {
-    operate_type_ = COS;
-  } else if (kernel_name == prim::kPrimTan->name()) {
-    operate_type_ = TAN;
-  }
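+  // The table lookup replaces the if-else chain above; note that .at() throws std::out_of_range for unknown names.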
+  operate_type_ = kArithmeticOpTypeMap.at(kernel_name);
   dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
   target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
 }
@@ -259,7 +285,10 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
     {GELU, Gelu<T>},   {SIN, Sin<T>},
     {COS, Cos<T>},     {TAN, Tan<T>},
     {ASIN, Asin<T>},   {ACOS, ACos<T>},
-    {ATAN, Atan<T>}};
+    {ATAN, Atan<T>},   {SINH, Sinh<T>},
+    {COSH, Cosh<T>},   {ASINH, Asinh<T>},
+    {ACOSH, Acosh<T>}, {ATANH, Atanh<T>}};
   while (start < lens) {
     size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
     threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end));
@@ -72,27 +72,25 @@ MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutp
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -102,6 +102,14 @@ enum OperateType {
   SIN,
   COS,
   TAN,
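+  // Hyperbolic and inverse-hyperbolic ops, their gradients, and the two-argument arctangent.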
+  SINH,
+  COSH,
+  ASINH,
+  ACOSH,
+  ATANH,
+  ASINHGRAD,
+  ACOSHGRAD,
+  ATAN2,
 };
 class CPUKernel : public kernel::KernelMod {
@@ -153,6 +153,48 @@ void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, si
   }
 }
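+// AsinhGrad: out = dy / sqrt(1 + x^2), i.e. the derivative of asinh(x) applied to the incoming gradient dy (input2).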
+template <typename T>
+void EltWiseGradCPUKernel::AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    T dividend = input2[i];
+    T divisor = sqrt(1 + input1[i] * input1[i]);
+    if (divisor == 0) {
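+      // Zero divisor: 0/0 produces NaN; otherwise saturate to +/-infinity (or the type's extreme values when T has no infinity).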
+      if (dividend == 0) {
+        out[i] = std::numeric_limits<T>::quiet_NaN();
+        continue;
+      }
+      if (std::numeric_limits<T>::has_infinity) {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
+      } else {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
+      }
+      continue;
+    }
+    out[i] = dividend / divisor;
+  }
+}
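+// AcoshGrad: out = dy / sqrt(x^2 - 1), the derivative of acosh(x); the divisor vanishes at |x| = 1, hence the same guard.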
+template <typename T>
+void EltWiseGradCPUKernel::AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    T dividend = input2[i];
+    T divisor = sqrt(input1[i] * input1[i] - 1);
+    if (divisor == 0) {
+      if (dividend == 0) {
+        out[i] = std::numeric_limits<T>::quiet_NaN();
+        continue;
+      }
+      if (std::numeric_limits<T>::has_infinity) {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
+      } else {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
+      }
+      continue;
+    }
+    out[i] = dividend / divisor;
+  }
+}
 void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
@@ -176,6 +218,10 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = ACOSGRAD;
   } else if (kernel_name == "AtanGrad") {
     operate_type_ = ATANGRAD;
+  } else if (kernel_name == "AsinhGrad") {
+    operate_type_ = ASINHGRAD;
+  } else if (kernel_name == "AcoshGrad") {
+    operate_type_ = ACOSHGRAD;
   } else {
     MS_LOG(EXCEPTION) << "Not support " << kernel_name;
   }
@@ -263,6 +309,10 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
       threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end));
     } else if (operate_type_ == ATANGRAD) {
       threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AtanGrad<T>, this, input1, input2, output, start, end));
+    } else if (operate_type_ == ASINHGRAD) {
+      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinhGrad<T>, this, input1, input2, output, start, end));
+    } else if (operate_type_ == ACOSHGRAD) {
+      threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AcoshGrad<T>, this, input1, input2, output, start, end));
     } else {
       MS_LOG(EXCEPTION) << "Not support " << operate_type_;
     }
@@ -56,6 +56,10 @@ class EltWiseGradCPUKernel : public CPUKernel {
   void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
   template <typename T>
   void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
   std::vector<size_t> input_shape0_;
   std::vector<size_t> input_shape1_;
   std::vector<size_t> input_element_num0_;
@@ -101,22 +105,21 @@ MS_REG_CPU_KERNEL(
   AsinGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
-  AsinGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  EltWiseGradCPUKernel);
 MS_REG_CPU_KERNEL(
   ACosGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
 MS_REG_CPU_KERNEL(
-  ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  AtanGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
 MS_REG_CPU_KERNEL(
-  AtanGrad,
+  AsinhGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
 MS_REG_CPU_KERNEL(
-  AtanGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  AcoshGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
@@ -195,9 +195,15 @@ inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoft
 inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
 inline const PrimitivePtr kPrimLstm = std::make_shared<Primitive>("Lstm");
 inline const PrimitivePtr kPrimTan = std::make_shared<Primitive>("Tan");
+inline const PrimitivePtr kPrimAtan2 = std::make_shared<Primitive>("Atan2");
 inline const PrimitivePtr kPrimAtan = std::make_shared<Primitive>("Atan");
 inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
+inline const PrimitivePtr kPrimSinh = std::make_shared<Primitive>("Sinh");
+inline const PrimitivePtr kPrimCosh = std::make_shared<Primitive>("Cosh");
 inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
+inline const PrimitivePtr kPrimAsinh = std::make_shared<Primitive>("Asinh");
+inline const PrimitivePtr kPrimAcosh = std::make_shared<Primitive>("Acosh");
+inline const PrimitivePtr kPrimAtanh = std::make_shared<Primitive>("Atanh");
 inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
 inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
 inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");
@@ -2561,7 +2561,7 @@ class Acosh(PrimitiveWithInfer):
         TypeError: If `input_x` is not a Tensor.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> acosh = ops.Acosh()
@@ -2597,7 +2597,7 @@ class Cosh(PrimitiveWithInfer):
         TypeError: If `input_x` is not a Tensor.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``CPU``
 
     Examples:
         >>> cosh = ops.Cosh()
@@ -2638,7 +2638,7 @@ class Asinh(PrimitiveWithInfer):
         TypeError: If `input_x` is not a Tensor.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> asinh = ops.Asinh()
@@ -2671,7 +2671,7 @@ class Sinh(PrimitiveWithInfer):
         Tensor, has the same shape as `input_x`.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``CPU``
 
     Examples:
         >>> sinh = ops.Sinh()
@@ -3886,7 +3886,7 @@ class Atanh(PrimitiveWithInfer):
        TypeError: If `input_x` is not a Tensor.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32)
@@ -3931,7 +3931,7 @@ class Atan2(_MathBinaryOp):
         TypeError: If `input_x` or `input_y` is not a Tensor.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([0, 1]), mindspore.float32)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops.operations import _grad_ops as G
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAcoshGrad(nn.Cell):
+    def __init__(self):
+        super(NetAcoshGrad, self).__init__()
+        self.acoshGrad = G.AcoshGrad()
+
+    def construct(self, x, dy):
+        return self.acoshGrad(x, dy)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_acosh_grad():
+    x = np.array([5, 4, 3]).astype('float32')
+    dy = np.array([1, 0, -1]).astype('float32')
+    acosh_grad = NetAcoshGrad()
+    output = acosh_grad(Tensor(x), Tensor(dy))
+    print(output)
+    expect = dy / np.sqrt(x * x - 1)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAcosh(nn.Cell):
+    def __init__(self):
+        super(NetAcosh, self).__init__()
+        self.acosh = P.Acosh()
+
+    def construct(self, x):
+        return self.acosh(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_acosh():
+    np_array = np.array([1, 2, 3, 4, 5]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetAcosh()
+    output = net(input_x)
+    print(output)
+    expect = np.arccosh(np_array)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops.operations import _grad_ops as G
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAsinhGrad(nn.Cell):
+    def __init__(self):
+        super(NetAsinhGrad, self).__init__()
+        self.asinhGrad = G.AsinhGrad()
+
+    def construct(self, x, dy):
+        return self.asinhGrad(x, dy)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_asinh_grad():
+    x = np.array([-0.5, 0, 0.5]).astype('float32')
+    dy = np.array([1, 0, -1]).astype('float32')
+    asinh_grad = NetAsinhGrad()
+    output = asinh_grad(Tensor(x), Tensor(dy))
+    print(output)
+    expect = dy / np.sqrt(1 + x * x)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAsinh(nn.Cell):
+    def __init__(self):
+        super(NetAsinh, self).__init__()
+        self.asinh = P.Asinh()
+
+    def construct(self, x):
+        return self.asinh(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_asinh():
+    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetAsinh()
+    output = net(input_x)
+    print(output)
+    expect = np.arcsinh(np_array)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAtan2(nn.Cell):
+    def __init__(self):
+        super(NetAtan2, self).__init__()
+        self.atan2 = P.Atan2()
+
+    def construct(self, x, y):
+        return self.atan2(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_atan2():
+    np_array = np.array([1, 2, 3, 4, 5]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetAtan2()
+    output = net(input_x, input_x)
+    print(output)
+    expect = np.arctan2(np_array, np_array)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetAtanh(nn.Cell):
+    def __init__(self):
+        super(NetAtanh, self).__init__()
+        self.atanh = P.Atanh()
+
+    def construct(self, x):
+        return self.atanh(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_atanh():
+    np_array = np.array([-0.5, 0, 0.5]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetAtanh()
+    output = net(input_x)
+    print(output)
+    expect = np.arctanh(np_array)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetCosh(nn.Cell):
+    def __init__(self):
+        super(NetCosh, self).__init__()
+        self.cosh = P.Cosh()
+
+    def construct(self, x):
+        return self.cosh(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_cosh():
+    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetCosh()
+    output = net(input_x)
+    print(output)
+    expect = np.cosh(np_array)
+    assert np.allclose(output.asnumpy(), expect)
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+
+class NetSinh(nn.Cell):
+    def __init__(self):
+        super(NetSinh, self).__init__()
+        self.sinh = P.Sinh()
+
+    def construct(self, x):
+        return self.sinh(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_sinh():
+    np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
+    input_x = Tensor(np_array)
+    net = NetSinh()
+    output = net(input_x)
+    print(output)
+    expect = np.sinh(np_array)
+    assert np.allclose(output.asnumpy(), expect)