Browse Source

!26841 [feat] [assistant] [I48O70] [I48O4W] Add Sinh and Cosh operators

Merge pull request !26841 from 桂宁馨/Sinh_Cosh
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
81ff4f4440
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 124 additions and 26 deletions
  1. +21
    -0
      mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.cc
  2. +8
    -0
      mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.h
  3. +13
    -13
      mindspore/core/ops/cosh.cc
  4. +12
    -13
      mindspore/core/ops/sinh.cc
  5. +2
    -0
      mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
  6. +34
    -0
      mindspore/python/mindspore/ops/_op_impl/aicpu/cosh.py
  7. +34
    -0
      mindspore/python/mindspore/ops/_op_impl/aicpu/sinh.py

+ 21
- 0
mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.cc View File

@@ -193,6 +193,16 @@ void ComplexSin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t
ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);
}

// Element-wise hyperbolic sine: out[i] = sinh(in[i]) for i in [0, size).
// Registered for complex element types via the LaunchKernelComplex dispatch
// map (kPrimSinh -> ComplexSinh<T>); `in` and `out` must each hold `size`
// elements and may not be null.
template <typename T>
void ComplexSinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
  // Capture the pointers by value: copying a pointer is free, and by-value
  // capture cannot dangle if the task object outlives this stack frame
  // (by-reference capture of locals was the only thing the original gained
  // nothing from).
  auto task = [in, out](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      // Unqualified call: ADL selects std::sinh for std::complex element types.
      out[i] = static_cast<T>(sinh(in[i]));
    }
  };
  ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);
}

template <typename T>
void ComplexCos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
@@ -203,6 +213,16 @@ void ComplexCos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t
ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);
}

// Element-wise hyperbolic cosine: out[i] = cosh(in[i]) for i in [0, size).
// Registered for complex element types via the LaunchKernelComplex dispatch
// map (kPrimCosh -> ComplexCosh<T>); `in` and `out` must each hold `size`
// elements and may not be null.
template <typename T>
void ComplexCosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
  // Capture the pointers by value rather than by reference: a pointer copy is
  // free and a by-value capture cannot dangle if the task object outlives
  // this stack frame.
  auto task = [in, out](size_t start, size_t end) {
    for (size_t i = start; i < end; i++) {
      // Unqualified call: ADL selects std::cosh for std::complex element types.
      out[i] = static_cast<T>(cosh(in[i]));
    }
  };
  ParallelLaunchAutoSearch(task, size, content, &content->parallel_search_info_);
}

template <typename T>
void Tan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
@@ -399,6 +419,7 @@ void ArithmeticSelfCpuKernelMod::LaunchKernelComplex(const std::vector<AddressPt
std::function<void(ArithmeticSelfCpuKernelMod *, const T *, T *, size_t)>>
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>}, {prim::kPrimAcosh->name(), ComplexAcosh<T>},
{prim::kPrimAsinh->name(), ComplexAsinh<T>}, {prim::kPrimNeg->name(), Neg<T>},
{prim::kPrimSinh->name(), ComplexSinh<T>}, {prim::kPrimCosh->name(), ComplexCosh<T>},
{prim::kPrimSin->name(), ComplexSin<T>}, {prim::kPrimCos->name(), ComplexCos<T>}};
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {


+ 8
- 0
mindspore/ccsrc/plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.h View File

@@ -151,6 +151,14 @@ MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
// CPU kernel registrations for Sinh/Cosh with complex64/complex128 input and
// output. NOTE(review): these presumably dispatch to the ComplexSinh/ComplexCosh
// helpers added to arithmetic_self_cpu_kernel.cc in this commit — confirm
// against the LaunchKernelComplex map.
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
                  ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
                  ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
                  ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
                  ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),


+ 13
- 13
mindspore/core/ops/cosh.cc View File

@@ -13,16 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ops/cosh.h"
#include <string>
#include <algorithm>
#include <memory>

#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include <map>
#include <string>

namespace mindspore {
namespace ops {
@@ -30,9 +25,12 @@ namespace {
abstract::ShapePtr CoshInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto x = input_args[0]->BuildShape();
const int64_t max_dim = 8;
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("The dimension of Cosh input", SizeToLong(in_shape.size()), kLessThan,
max_dim, prim_name);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
@@ -40,16 +38,18 @@ abstract::ShapePtr CoshInferShape(const PrimitivePtr &primitive, const std::vect
}

TypePtr CoshInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
auto x_dtype = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types_with_complex, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
return x_dtype;
}
} // namespace

AbstractBasePtr CoshInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = CoshInferType(primitive, input_args);
auto infer_shape = CoshInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);


+ 12
- 13
mindspore/core/ops/sinh.cc View File

@@ -13,16 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/sinh.h"

#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/sinh.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
#include <map>
#include <string>

namespace mindspore {
namespace ops {
@@ -30,9 +25,12 @@ namespace {
abstract::ShapePtr SinhInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, 1, prim_name);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto x = input_args[0]->BuildShape();
const int64_t max_dim = 8;
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("The dimension of Sinh input", SizeToLong(in_shape.size()), kLessThan,
max_dim, prim_name);
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
@@ -40,8 +38,9 @@ abstract::ShapePtr SinhInferShape(const PrimitivePtr &primitive, const std::vect
}

TypePtr SinhInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
auto x_dtype = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types_with_complex, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, valid_types, prim->name());
return x_dtype;
}
} // namespace
@@ -49,8 +48,8 @@ TypePtr SinhInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
AbstractBasePtr SinhInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kInputNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, primitive->name());
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = SinhInferType(primitive, input_args);
auto infer_shape = SinhInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);


+ 2
- 0
mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py View File

@@ -48,6 +48,8 @@ from .reshape import _reshape_aicpu
from .flatten import _flatten_aicpu
from .sin import _sin_aicpu
from .cos import _cos_aicpu
from .sinh import _sinh_aicpu
from .cosh import _cosh_aicpu
from .squeeze import _squeeze_aicpu
from .acos import _acos_aicpu
from .acos_grad import _acos_grad_aicpu


+ 34
- 0
mindspore/python/mindspore/ops/_op_impl/aicpu/cosh.py View File

@@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Cosh op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

# Registration metadata for the Cosh AiCPU kernel: element-wise hyperbolic
# cosine with one required input "x" and one required output "y".
# Each dtype_format pair is (input dtype/format, output dtype/format);
# fp16/fp32/fp64 and complex64/complex128 are supported.
cosh_op_info = AiCPURegOp("Cosh") \
    .fusion_type("ELEMWISE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.C128_Default) \
    .get_op_info()


@op_info_register(cosh_op_info)
def _cosh_aicpu():
    """Cosh AiCPU register.

    Registration happens as a side effect of the decorator; the function
    body is intentionally empty.
    """
    return

+ 34
- 0
mindspore/python/mindspore/ops/_op_impl/aicpu/sinh.py View File

@@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Sinh op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

# Registration metadata for the Sinh AiCPU kernel: element-wise hyperbolic
# sine with one required input "x" and one required output "y".
# Each dtype_format pair is (input dtype/format, output dtype/format);
# fp16/fp32/fp64 and complex64/complex128 are supported.
sinh_op_info = AiCPURegOp("Sinh") \
    .fusion_type("ELEMWISE") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.C128_Default) \
    .get_op_info()


@op_info_register(sinh_op_info)
def _sinh_aicpu():
    """Sinh AiCPU register.

    Registration happens as a side effect of the decorator; the function
    body is intentionally empty.
    """
    return

Loading…
Cancel
Save