
!26420 [feat] [assistant] [I48O90, I48O4P] Add AsinhGrad, Asinh

Merge pull request !26420 from 桂宁馨/Asinh
i-robot (Gitee) committed 4 years ago, commit 908beb6b2f
15 changed files with 196 additions and 69 deletions
  1. mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.cc (+11 -0)
  2. mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.h (+7 -8)
  3. mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc (+16 -1)
  4. mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.h (+17 -0)
  5. mindspore/core/base/core_ops.h (+4 -2)
  6. mindspore/core/ops/asinh.cc (+18 -21)
  7. mindspore/core/ops/asinh.h (+15 -4)
  8. mindspore/core/ops/grad/asinh_grad.cc (+16 -24)
  9. mindspore/core/ops/grad/asinh_grad.h (+13 -5)
  10. mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py (+2 -0)
  11. mindspore/python/mindspore/ops/_op_impl/aicpu/asinh.py (+34 -0)
  12. mindspore/python/mindspore/ops/_op_impl/aicpu/asinh_grad.py (+35 -0)
  13. mindspore/python/mindspore/ops/operations/_grad_ops.py (+2 -1)
  14. mindspore/python/mindspore/ops/operations/math_ops.py (+2 -3)
  15. tests/ut/python/ops/test_ops.py (+4 -0)

mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.cc (+11 -0)

@@ -233,6 +233,16 @@ void Cosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);
}

template <typename T>
void ComplexAsinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(asinh(in[i]));
}
};
ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);
}

template <typename T>
void Asinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
@@ -391,6 +401,7 @@ void ArithmeticSelfCpuKernelMod::LaunchKernelComplex(const std::vector<AddressPt
std::function<void(ArithmeticSelfCpuKernelMod *, const T *, T *, size_t)>>
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>},
{prim::kPrimAcosh->name(), ComplexAcosh<T>},
{prim::kPrimAsinh->name(), ComplexAsinh<T>},
{prim::kPrimNeg->name(), Neg<T>}};
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
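
For readers who want the semantics outside C++: ComplexAsinh applies the inverse hyperbolic sine element-wise over the buffer, with ParallelLaunchAutoSearch splitting the index range across threads. A minimal NumPy sketch of the same math (not the kernel itself; np.arcsinh supports complex inputs on the principal branch, matching std::asinh):

import numpy as np

x = np.array([1 + 2j, -0.5 + 0.1j], dtype=np.complex64)
y = np.arcsinh(x)                             # element-wise, like out[i] = asinh(in[i])
assert np.allclose(np.sinh(y), x, atol=1e-5)  # round trip: sinh(asinh(x)) == x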


mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.h (+7 -8)

@@ -21,15 +21,14 @@
#include <memory>
#include <set>
#include <vector>

- using complex64 = std::complex<float>;
- using complex128 = std::complex<double>;

#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
+ using complex64 = std::complex<float>;
+ using complex128 = std::complex<double>;

class ArithmeticSelfCpuKernelMod : public NativeCpuKernelMod {
public:
ArithmeticSelfCpuKernelMod() = default;
@@ -164,6 +163,10 @@ MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutp
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
@@ -211,10 +214,6 @@ MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddO
IdentityCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
IdentityCpuKernelMod, bool);
- MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
- IdentityCpuKernelMod, complex64);
- MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
- IdentityCpuKernelMod, complex128);
} // namespace kernel
} // namespace mindspore



mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.cc (+16 -1)

@@ -190,6 +190,20 @@ void EltWiseGradCpuKernelMod<T>::AsinhGrad(const T *input1, const T *input2, T *
}
}

template <typename T>
void EltWiseGradCpuKernelMod<T>::ComplexAsinhGrad(const T *input1, const T *input2, T *out, size_t start,
size_t end) const {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = std::conj(cosh(input1[i]));
if (divisor == static_cast<T>(0)) {
out[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
out[i] = dividend / divisor;
}
}

template <typename T>
void EltWiseGradCpuKernelMod<T>::AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const {
for (size_t i = start; i < end; i++) {
@@ -291,7 +305,8 @@ void EltWiseGradCpuKernelMod<T>::InitComputeFunc() {
if constexpr ((std::is_same_v<T, complex64>) || (std::is_same_v<T, complex128>)) {
static const std::map<std::string,
std::function<void(EltWiseGradCpuKernelMod *, const T *, const T *, T *, size_t, size_t)>>
- elt_map{{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuKernelMod<T>::ComplexAcoshGrad}};
+ elt_map{{prim::kPrimAcoshGrad->name(), &EltWiseGradCpuKernelMod<T>::ComplexAcoshGrad},
+ {prim::kPrimAsinhGrad->name(), &EltWiseGradCpuKernelMod<T>::ComplexAsinhGrad}};
if (elt_map.find(kernel_name_) == elt_map.end()) {
MS_LOG(EXCEPTION) << "EltWiseGradCpuKernelMod does not support " << kernel_name_;
}
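
The complex grad kernel computes out[i] = dy[i] / conj(cosh(y[i])), where y is the forward output asinh(x); the real-typed AsinhGrad uses plain cosh. Since d/dx asinh(x) = 1/sqrt(1 + x^2) = 1/cosh(asinh(x)), the formula can be cross-checked in a few lines of NumPy (a sketch of the real-valued case, not the kernel):

import numpy as np

x = np.array([0.5, 1.0, 3.0])
y = np.arcsinh(x)        # forward output, fed back as input1
dy = np.ones_like(x)     # upstream gradient (input2)
dx = dy / np.cosh(y)     # what AsinhGrad computes
assert np.allclose(dx, 1.0 / np.sqrt(1.0 + x * x))  # analytic derivative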


mindspore/ccsrc/plugin/device/cpu/kernel/eltwise_grad_cpu_kernel.h (+17 -0)

@@ -55,6 +55,7 @@ class EltWiseGradCpuKernelMod : public NativeCpuKernelMod {
void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void ComplexAsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void ComplexAcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
void SoftplusGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) const;
@@ -125,6 +126,22 @@ MS_REG_CPU_KERNEL_T(
AsinhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
AsinhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EltWiseGradCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(AsinhGrad,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EltWiseGradCpuKernelMod, complex64);
MS_REG_CPU_KERNEL_T(AsinhGrad,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EltWiseGradCpuKernelMod, complex128);
MS_REG_CPU_KERNEL_T(
AcoshGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),


mindspore/core/base/core_ops.h (+4 -2)

@@ -74,6 +74,8 @@ constexpr auto kMatrixInverse = "MatrixInverse";
constexpr auto kMatrixDeterminant = "MatrixDeterminant";
constexpr auto kLogMatrixDeterminant = "LogMatrixDeterminant";
constexpr auto kCos = "Cos";
constexpr auto kAsinh = "Asinh";
constexpr auto kAsinhGrad = "AsinhGrad";
constexpr auto kAbs = "Abs";
constexpr auto kTrunc = "Trunc";
constexpr auto kLpNorm = "LpNorm";
@@ -384,7 +386,7 @@ MS_CORE_API inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("A
MS_CORE_API inline const PrimitivePtr kPrimSinh = std::make_shared<Primitive>("Sinh");
MS_CORE_API inline const PrimitivePtr kPrimCosh = std::make_shared<Primitive>("Cosh");
MS_CORE_API inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>(kTanh);
- MS_CORE_API inline const PrimitivePtr kPrimAsinh = std::make_shared<Primitive>("Asinh");
+ MS_CORE_API inline const PrimitivePtr kPrimAsinh = std::make_shared<Primitive>(kAsinh);
MS_CORE_API inline const PrimitivePtr kPrimAcosh = std::make_shared<Primitive>(kAcosh);
MS_CORE_API inline const PrimitivePtr kPrimAtanh = std::make_shared<Primitive>("Atanh");
MS_CORE_API inline const PrimitivePtr kPrimApplyGradientDescent = std::make_shared<Primitive>("ApplyGradientDescent");
@@ -700,7 +702,7 @@ MS_CORE_API inline const PrimitivePtr kPrimACos = std::make_shared<Primitive>(kA
MS_CORE_API inline const PrimitivePtr kPrimAsinGrad = std::make_shared<Primitive>("AsinGrad");
MS_CORE_API inline const PrimitivePtr kPrimACosGrad = std::make_shared<Primitive>(kACosGrad);
MS_CORE_API inline const PrimitivePtr kPrimAtanGrad = std::make_shared<Primitive>("AtanGrad");
- MS_CORE_API inline const PrimitivePtr kPrimAsinhGrad = std::make_shared<Primitive>("AsinhGrad");
+ MS_CORE_API inline const PrimitivePtr kPrimAsinhGrad = std::make_shared<Primitive>(kAsinhGrad);
MS_CORE_API inline const PrimitivePtr kPrimAcoshGrad = std::make_shared<Primitive>("AcoshGrad");
MS_CORE_API inline const PrimitivePtr kPrimFloorMod = std::make_shared<Primitive>("FloorMod");
MS_CORE_API inline const PrimitivePtr kPrimCdist = std::make_shared<Primitive>(kCdist);


mindspore/core/ops/asinh.cc (+18 -21)

@@ -14,49 +14,46 @@
* limitations under the License.
*/

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "ops/asinh.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/param_validator.h"

namespace mindspore {
namespace ops {
namespace {
const size_t InputNum = 1;
const int64_t MaxDim = 8;

abstract::ShapePtr AsinhInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
- auto x = input_args[0]->BuildShape();
+ auto x = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("The dimension of Asinh input", SizeToLong(in_shape.size()), kLessThan,
MaxDim, prim_name);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}

TypePtr AsinhInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
- MS_EXCEPTION_IF_NULL(input_args[0]);
- auto x_type = input_args[0]->BuildType();
- (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, common_valid_types, prim_name);
- return x_type;
+ const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
+ auto x_type = input_args[kInputIndex0]->BuildType();
+ (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
+ return input_args[kInputIndex0]->BuildType();
}
} // namespace

AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
- const int64_t input_num = 1;
- CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
- auto infer_type = AsinhInferType(primitive, input_args);
- auto infer_shape = AsinhInferShape(primitive, input_args);
- return abstract::MakeAbstract(infer_shape, infer_type);
+ auto prim_name = primitive->name();
+ (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, InputNum, prim_name);
+ auto types = AsinhInferType(primitive, input_args);
+ auto shapes = AsinhInferShape(primitive, input_args);
+ return abstract::MakeAbstract(shapes, types);
}

REGISTER_PRIMITIVE_EVAL_IMPL(Asinh, prim::kPrimAsinh, AsinhInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
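
Taken together, the rewritten infer functions pin down the op contract: exactly one input, rank strictly below MaxDim (8), dtype in {float16, float32, float64, complex64, complex128}, and output shape/dtype equal to the input's. A rough Python mirror of that contract (hypothetical helper, not MindSpore code):

import numpy as np

VALID_DTYPES = (np.float16, np.float32, np.float64, np.complex64, np.complex128)
MAX_DIM = 8

def asinh_infer(x: np.ndarray):
    """Mirror of AsinhInferShape/AsinhInferType: output matches input."""
    if x.ndim >= MAX_DIM:
        raise ValueError("The dimension of Asinh input must be less than 8")
    if x.dtype.type not in VALID_DTYPES:
        raise TypeError(f"Asinh does not support dtype {x.dtype}")
    return x.shape, x.dtype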

mindspore/core/ops/asinh.h (+15 -4)

@@ -19,23 +19,34 @@

#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameAsinh = "Asinh";
- class Asinh : public PrimitiveC {
+ /// \brief Computes arcsinh of input tensors element-wise.
+ /// Refer to Python API @ref mindspore.ops.Asinh for more details.
+ class MS_CORE_API Asinh : public PrimitiveC {
public:
- Asinh() : PrimitiveC(kNameAsinh) { InitIOName({"x"}, {"output"}); }
+ /// \brief Constructor.
+ Asinh() : PrimitiveC(kNameAsinh) { InitIOName({"x"}, {"y"}); }
+ /// \brief Destructor.
~Asinh() = default;

MS_DECLARE_PARENT(Asinh, PrimitiveC);
void Init() {}
};

AbstractBasePtr AsinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);

using PrimAsinhPtr = std::shared_ptr<Asinh>;
} // namespace ops
} // namespace mindspore



mindspore/core/ops/grad/asinh_grad.cc (+16 -24)

@@ -15,24 +15,16 @@
*/

#include "ops/grad/asinh_grad.h"
#include <algorithm>
#include <set>
#include "abstract/param_validator.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
const size_t InputNum = 2;

abstract::ShapePtr AsinhGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
- const int64_t input_num = 2;
- (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
- for (const auto &item : input_args) {
- MS_EXCEPTION_IF_NULL(item);
- }
- auto x = input_args[0]->BuildShape();
+ (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
+ auto x = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
@@ -40,24 +32,24 @@ abstract::ShapePtr AsinhGradInferShape(const PrimitivePtr &primitive, const std:
}

TypePtr AsinhGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
- const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
- const int64_t input_num = 2;
- (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
- MS_EXCEPTION_IF_NULL(input_args[0]);
- auto x_type = input_args[0]->BuildType();
- MS_EXCEPTION_IF_NULL(x_type);
- (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_types, prim_name);
- return x_type;
+ const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
+ std::map<std::string, TypePtr> types;
+ (void)types.emplace("y", input_args[kInputIndex0]->BuildType());
+ (void)types.emplace("dy", input_args[kInputIndex1]->BuildType());
+ (void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
+ return input_args[kInputIndex0]->BuildType();
}
} // namespace

AbstractBasePtr AsinhGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
- auto type = AsinhGradInferType(primitive, input_args);
- auto shape = AsinhGradInferShape(primitive, input_args);
- return abstract::MakeAbstract(shape, type);
+ MS_EXCEPTION_IF_NULL(primitive);
+ auto prim_name = primitive->name();
+ (void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, InputNum, prim_name);
+ auto types = AsinhGradInferType(primitive, input_args);
+ auto shapes = AsinhGradInferShape(primitive, input_args);
+ return abstract::MakeAbstract(shapes, types);
}

REGISTER_PRIMITIVE_EVAL_IMPL(AsinhGrad, prim::kPrimAsinhGrad, AsinhGradInfer, nullptr, true);
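
The grad infer adds one rule on top of the forward contract: both inputs must share a single dtype from the valid set (CheckTensorTypeSame over {"y", "dy"}). A hypothetical Python mirror of that check:

import numpy as np

VALID_DTYPES = (np.float16, np.float32, np.float64, np.complex64, np.complex128)

def asinh_grad_infer_type(y: np.ndarray, dy: np.ndarray):
    """Mirror of AsinhGradInferType: y and dy must agree on one valid dtype."""
    if y.dtype != dy.dtype:
        raise TypeError(f"y and dy must share a dtype, got {y.dtype} vs {dy.dtype}")
    if y.dtype.type not in VALID_DTYPES:
        raise TypeError(f"AsinhGrad does not support dtype {y.dtype}")
    return y.dtype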


mindspore/core/ops/grad/asinh_grad.h (+13 -5)

@@ -16,24 +16,32 @@

#ifndef MINDSPORE_CORE_OPS_ASINH_GRAD_H_
#define MINDSPORE_CORE_OPS_ASINH_GRAD_H_

#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include <set>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameAsinhGrad = "AsinhGrad";

class AsinhGrad : public PrimitiveC {
public:
- AsinhGrad() : PrimitiveC(kNameAsinhGrad) { InitIOName({"x"}, {"output"}); }
+ AsinhGrad() : PrimitiveC(kNameAsinhGrad) { InitIOName({"y", "dy"}, {"z"}); }
~AsinhGrad() = default;

MS_DECLARE_PARENT(AsinhGrad, PrimitiveC);
};

AbstractBasePtr AsinhGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimAsinhGradPtr = std::shared_ptr<AsinhGrad>;
} // namespace ops
} // namespace mindspore



mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py (+2 -0)

@@ -93,6 +93,8 @@ from .trans_data import _trans_data_aicpu
from .stack_push_pop import _stack_init_aicpu
from .stack_push_pop import _stack_push_aicpu
from .stack_push_pop import _stack_pop_aicpu
from .asinh import _asinh_aicpu
from .asinh_grad import _asinh_grad_aicpu
from .stack_push_pop import _stack_destroy_aicpu
from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu
from .resize_bilinear import _resize_bilinear_aicpu


mindspore/python/mindspore/ops/_op_impl/aicpu/asinh.py (+34 -0)

@@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Asinh op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

asinh_op_info = AiCPURegOp("Asinh") \
.fusion_type("ELEMWISE") \
.input(0, "x", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.C128_Default) \
.get_op_info()


@op_info_register(asinh_op_info)
def _asinh_aicpu():
"""Asinh AiCPU register"""
return

mindspore/python/mindspore/ops/_op_impl/aicpu/asinh_grad.py (+35 -0)

@@ -0,0 +1,35 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""AsinhGrad op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

asinh_grad_op_info = AiCPURegOp("AsinhGrad") \
.fusion_type("ELEMWISE") \
.input(0, "y", "required") \
.input(1, "dy", "required") \
.output(0, "z", "required") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
.dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
.get_op_info()


@op_info_register(asinh_grad_op_info)
def _asinh_grad_aicpu():
"""AsinhGrad AiCPU register"""
return

mindspore/python/mindspore/ops/operations/_grad_ops.py (+2 -1)

@@ -65,12 +65,13 @@ class AsinGrad(Primitive):
"""Initialize AsinGrad"""


- class AsinhGrad(PrimitiveWithInfer):
+ class AsinhGrad(Primitive):
"""Performs grad of Asinh operation."""

@prim_attr_register
def __init__(self):
"""Initialize AsinhGrad"""
+ self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])


class ReciprocalGrad(Primitive):


mindspore/python/mindspore/ops/operations/math_ops.py (+2 -3)

@@ -3473,7 +3473,6 @@ class Asinh(Primitive):
Inputs:
- **x** (Tensor) - The shape of tensor is
:math:`(N,*)` where :math:`*` means, any number of additional dimensions, its rank should be less than 8.
- The data type should be one of the following types: float16, float32.

Outputs:
Tensor, has the same shape and type as `x`.
@@ -3489,13 +3488,13 @@ class Asinh(Primitive):
>>> x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), mindspore.float32)
>>> output = asinh(x)
>>> print(output)
- [-2.3124385 1.1947632 1.8184465 5.298342 ]
+ [-2.3124382 1.1947632 1.8184465 5.298342 ]
"""

@prim_attr_register
def __init__(self):
"""Initialize Asinh"""
+ self.init_prim_io_names(inputs=['x'], outputs=['y'])

class Sinh(Primitive):
r"""


tests/ut/python/ops/test_ops.py (+4 -0)

@@ -1268,6 +1268,10 @@ test_case_math_ops = [
'block': P.Asinh(),
'desc_inputs': [[3, 4, 5]],
'desc_bprop': [[3, 4, 5]]}),
('AsinhGrad', {
'block': G.AsinhGrad(),
'desc_inputs': [[2, 3], [2, 3]],
'skip': ['backward']}),
('Tan', {
'block': P.Tan(),
'desc_inputs': [[2, 3]],
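
The new UT entry only exercises the forward call ('skip': ['backward']). For a quick by-hand check of the operator it wraps, a sketch (PyNative mode and a build with these kernels assumed; import path as used in test_ops.py):

import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

y = Tensor(np.array([0.5, 1.0]), ms.float32)    # forward output of Asinh
dy = Tensor(np.array([1.0, 1.0]), ms.float32)   # upstream gradient
print(G.AsinhGrad()(y, dy))                     # expect dy / cosh(y)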

