
!1757 Complete vm ops for Cosh and Sinh

Merge pull request !1757 from lihongkang/lhk_master
tags/v0.5.0-beta
mindspore-ci-bot committed 5 years ago · commit a2ab975deb
7 changed files with 176 additions and 6 deletions
  1. mindspore/ops/_grad/grad_math_ops.py  (+24 −0)
  2. mindspore/ops/_op_impl/tbe/__init__.py  (+2 −0)
  3. mindspore/ops/_op_impl/tbe/cosh.py  (+37 −0)
  4. mindspore/ops/_op_impl/tbe/sinh.py  (+37 −0)
  5. mindspore/ops/operations/__init__.py  (+6 −2)
  6. mindspore/ops/operations/math_ops.py  (+58 −1)
  7. tests/ut/python/ops/test_ops.py  (+12 −3)

mindspore/ops/_grad/grad_math_ops.py  (+24 −0)

@@ -793,6 +793,18 @@ def get_bprop_asinh(self):
    return bprop


@bprop_getters.register(P.Sinh)
def get_bprop_sinh(self):
    """Grad definition for `Sinh` operation."""
    cosh = P.Cosh()

    def bprop(x, out, dout):
        dx = cosh(x) * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Cos)
def get_bprop_cos(self):
"""Grad definition for `Cos` operation."""
@@ -830,6 +842,18 @@ def get_bprop_acosh(self):
    return bprop


@bprop_getters.register(P.Cosh)
def get_bprop_cosh(self):
    """Grad definition for `Cosh` operation."""
    sinh = P.Sinh()

    def bprop(x, out, dout):
        dx = sinh(x) * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.Abs)
def get_bprop_abs(self):
"""Grad definition for `Abs` operation."""


mindspore/ops/_op_impl/tbe/__init__.py  (+2 −0)

@@ -227,3 +227,5 @@ from .asinh_grad import _asinh_grad_tbe
from .atan import _atan_tbe
from .atan_grad import _atan_grad_tbe
from .atanh import _atanh_tbe
from .cosh import _cosh_tbe
from .sinh import _sinh_tbe
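
These imports matter purely for their side effect: importing cosh.py and sinh.py at package init time executes each file's @op_info_register decorator, which records the op info with the framework. A simplified sketch of the mechanism (conceptual only, not MindSpore's actual implementation):

# Conceptual model of decorator-driven registration at import time.
_OP_INFO_REGISTRY = {}

def op_info_register(op_info):
    def decorator(func):
        # Runs when the module is imported, so listing the module in
        # __init__.py is what makes the kernel info visible to the framework.
        _OP_INFO_REGISTRY[func.__name__] = op_info
        return func
    return decorator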

mindspore/ops/_op_impl/tbe/cosh.py  (+37 −0)

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Cosh op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

cosh_op_info = TBERegOp("Cosh") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("cosh.so") \
    .compute_cost(10) \
    .kernel_name("cosh") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(cosh_op_info)
def _cosh_tbe():
    """Cosh TBE register"""
    return

mindspore/ops/_op_impl/tbe/sinh.py  (+37 −0)

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Sinh op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

sinh_op_info = TBERegOp("Sinh") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("sinh.so") \
    .compute_cost(10) \
    .kernel_name("sinh") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(sinh_op_info)
def _sinh_tbe():
    """Sinh TBE register"""
    return

mindspore/ops/operations/__init__.py  (+6 −2)

@@ -40,7 +40,8 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSumm
from .control_ops import ControlDepend, GeSwitch, Merge
from .inner_ops import ScalarCast

-from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, BitwiseXor,
+from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
+                       BitwiseXor,
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
                        Cos, Div, Equal, EqualCount, Exp, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
                        Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
@@ -50,7 +51,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2
                        NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum,
                        Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
-                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh)
+                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh)

from .random_ops import (RandomChoiceWithMask)
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
                     BiasAdd, Conv2D,
@@ -245,6 +247,7 @@ __all__ = [
    'Asinh',
    "PReLU",
    "Cos",
    "Cosh",
    "ACos",
    "Diag",
    "DiagPart",
@@ -253,6 +256,7 @@ __all__ = [
    'AssignAdd',
    'AssignSub',
    "Sin",
    "Sinh",
    "Asin",
    "LSTM",
    "Abs",


mindspore/ops/operations/math_ops.py  (+58 −1)

@@ -1359,6 +1359,35 @@ class Acosh(PrimitiveWithInfer):
        return x_dtype


class Cosh(PrimitiveWithInfer):
    """
    Computes hyperbolic cosine of input element-wise.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, has the same shape as `input_x`.

    Examples:
        >>> cosh = P.Cosh()
        >>> input_x = Tensor(np.array([0.24, 0.83, 0.31, 0.09]), mindspore.float32)
        >>> output = cosh(input_x)
        [1.0289385 1.364684 1.048436 1.0040528]
    """

    @prim_attr_register
    def __init__(self):
        """init Cosh"""

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
        return x_dtype


class Asinh(PrimitiveWithInfer):
    """
    Compute inverse hyperbolic sine of x element-wise.
@@ -1376,7 +1405,6 @@ class Asinh(PrimitiveWithInfer):
        [-2.3212, 1.1976, 1.8184, 5.2983]
    """

-
    @prim_attr_register
    def __init__(self):
        """init Asinh"""
@@ -1389,6 +1417,35 @@ class Asinh(PrimitiveWithInfer):
        return x_dtype


class Sinh(PrimitiveWithInfer):
    """
    Computes hyperbolic sine of input element-wise.

    Inputs:
        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.

    Outputs:
        Tensor, has the same shape as `input_x`.

    Examples:
        >>> sinh = P.Sinh()
        >>> input_x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)
        >>> output = sinh(input_x)
        [0.6604918 0.28367308 0.44337422 0.6604918]
    """

    @prim_attr_register
    def __init__(self):
        """init Sinh"""

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
        return x_dtype


class _LogicBinaryOp(_BinaryOp):
"""
Define logic binary operators.
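
As context for reviewers, a minimal end-to-end check of the new primitives and their registered bprops could look like this (an illustrative sketch, not part of the patch; the GradOperation usage is assumed from the API of this era):

import numpy as np
from mindspore import Tensor, nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C

class SinhNet(nn.Cell):
    def __init__(self):
        super(SinhNet, self).__init__()
        self.sinh = P.Sinh()

    def construct(self, x):
        return self.sinh(x)

x = Tensor(np.array([0.62, 0.28, 0.43], np.float32))
net = SinhNet()
out = net(x)                       # forward: sinh(x)
grad_fn = C.GradOperation('grad')  # assumed signature for this version
dx = grad_fn(net)(x)               # backward: should equal cosh(x)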


tests/ut/python/ops/test_ops.py  (+12 −3)

@@ -128,7 +128,7 @@ class NetForFlattenComposed(nn.Cell):
        self.flatten = P.Flatten()

    def construct(self, x, y):
-        return self.flatten(x+x) + y
+        return self.flatten(x + x) + y


class ArgmaxNet(nn.Cell):
@@ -281,6 +281,7 @@ class ApplyRMSNet(nn.Cell):
        out = self.apply_rms(self.var, self.ms, self.moment, self.lr, grad, self.rho, self.momentum, self.epsilon)
        return out


test_case_math_ops = [
    ('BitwiseAnd', {
        'block': P.BitwiseAnd(),
@@ -732,6 +733,14 @@ test_case_math_ops = [
        'block': P.Atanh(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('Cosh', {
        'block': P.Cosh(),
        'desc_inputs': [[3, 4, 5]],
        'desc_bprop': [[3, 4, 5]]}),
    ('Sinh', {
        'block': P.Sinh(),
        'desc_inputs': [[3, 4, 5]],
        'desc_bprop': [[3, 4, 5]]}),
]

test_case_nn_ops = [
@@ -1301,7 +1310,7 @@ test_case_array_ops = [
        'desc_inputs': [(Tensor(np.array([1], np.float32)),
                         Tensor(np.array([1], np.float32)),
                         Tensor(np.array([1], np.float32)))],
-        'desc_bprop': [[3,]]}),
+        'desc_bprop': [[3, ]]}),
    ('Pack_0', {
        'block': NetForPackInput(P.Pack()),
        'desc_inputs': [[2, 2], [2, 2], [2, 2]],
@@ -1464,7 +1473,7 @@ test_case = functools.reduce(lambda x, y: x + y, test_case_lists)
test_exec_case = test_case

test_backward_exec_case = filter(lambda x: 'skip' not in x[1] or
-                                'backward' not in x[1]['skip'], test_case)
+                                 'backward' not in x[1]['skip'], test_case)


@non_graph_engine
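
With the patch applied, the new cases run as part of the standard exec and backward-exec sweeps above; to target just them, something along these lines should work (exact invocation depends on the repo's test setup):

pytest tests/ut/python/ops/test_ops.py -k "Cosh or Sinh"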

