
add vm ops: Asin, Asinh, AsinGrad, AsinhGrad

tags/v0.5.0-beta
fangzehua authored 5 years ago, commit c0b8a90105
10 changed files with 317 additions and 23 deletions:

  1. mindspore/ops/_grad/grad_math_ops.py (+22, -0)
  2. mindspore/ops/_op_impl/tbe/__init__.py (+4, -0)
  3. mindspore/ops/_op_impl/tbe/asin.py (+37, -0)
  4. mindspore/ops/_op_impl/tbe/asin_grad.py (+43, -0)
  5. mindspore/ops/_op_impl/tbe/asinh.py (+37, -0)
  6. mindspore/ops/_op_impl/tbe/asinh_grad.py (+43, -0)
  7. mindspore/ops/operations/__init__.py (+4, -1)
  8. mindspore/ops/operations/_grad_ops.py (+39, -0)
  9. mindspore/ops/operations/math_ops.py (+80, -22)
  10. tests/ut/python/ops/test_ops.py (+8, -0)

mindspore/ops/_grad/grad_math_ops.py (+22, -0)

@@ -770,6 +770,28 @@ def get_bprop_sin(self):
     return bprop
 
 
+@bprop_getters.register(P.Asin)
+def get_bprop_asin(self):
+    """Grad definition for `Asin` operation."""
+    input_grad = G.AsinGrad()
+
+    def bprop(x, out, dout):
+        dx = input_grad(x, dout)
+        return (dx,)
+    return bprop
+
+
+@bprop_getters.register(P.Asinh)
+def get_bprop_asinh(self):
+    """Grad definition for `Asinh` operation."""
+    input_grad = G.AsinhGrad()
+
+    def bprop(x, out, dout):
+        dx = input_grad(out, dout)
+        return (dx,)
+    return bprop
+
+
 @bprop_getters.register(P.Cos)
 def get_bprop_cos(self):
     """Grad definition for `Cos` operation."""


mindspore/ops/_op_impl/tbe/__init__.py (+4, -0)

@@ -208,3 +208,7 @@ from .bitwise_xor import bitwise_xor_op_info
 from .reduce_all import _reduce_all_tbe
 from .sparse_apply_adagrad import _sparse_apply_adagrad_tbe
 from .unsorted_segment_min import _unsorted_segment_min_tbe
+from .asin import _asin_tbe
+from .asin_grad import _asin_grad_tbe
+from .asinh import _asinh_tbe
+from .asinh_grad import _asinh_grad_tbe

mindspore/ops/_op_impl/tbe/asin.py (+37, -0)

@@ -0,0 +1,37 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Asin op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+asin_op_info = TBERegOp("Asin") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("asin.so") \
+    .compute_cost(10) \
+    .kernel_name("asin") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(asin_op_info)
+def _asin_tbe():
+    """Asin TBE register"""
+    return
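The TBERegOp block above only describes the prebuilt Ascend kernel (asin.so) to the operator selector: one required input, one output, float16/float32 in the 5HD layout, with `formatAgnostic` letting other formats map through. Once the forward primitive lands in math_ops.py (see below), usage is ordinary; a minimal sketch, assuming an Ascend context where this TBE kernel gets selected:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

# Element-wise arcsine; on Ascend this dispatches to the kernel registered above.
asin = P.Asin()
output = asin(Tensor(np.array([0.74, 0.04, 0.30, 0.56]), mindspore.float32))
print(output)  # -> approx. [0.8331, 0.0400, 0.3047, 0.5944]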

mindspore/ops/_op_impl/tbe/asin_grad.py (+43, -0)

@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""AsinGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+asin_grad_op_info = TBERegOp("AsinGrad") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("asin_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("asin_grad") \
+    .partial_flag(True) \
+    .input(0, "y", None, "required", "all") \
+    .input(1, "dy", None, "required", "all") \
+    .output(0, "z", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(asin_grad_op_info)
+def _asin_grad_tbe():
+    """AsinGrad TBE register"""
+    return

mindspore/ops/_op_impl/tbe/asinh.py (+37, -0)

@@ -0,0 +1,37 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Asinh op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+asinh_op_info = TBERegOp("Asinh") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("asinh.so") \
+    .compute_cost(10) \
+    .kernel_name("asinh") \
+    .partial_flag(True) \
+    .op_pattern("formatAgnostic") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(asinh_op_info)
+def _asinh_tbe():
+    """Asinh TBE register"""
+    return

mindspore/ops/_op_impl/tbe/asinh_grad.py (+43, -0)

@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""AsinhGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+asinh_grad_op_info = TBERegOp("AsinhGrad") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("asinh_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("asinh_grad") \
+    .partial_flag(True) \
+    .input(0, "y", False, "required", "all") \
+    .input(1, "dy", False, "required", "all") \
+    .output(0, "z", False, "required", "all") \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(asinh_grad_op_info)
+def _asinh_grad_tbe():
+    """AsinhGrad TBE register"""
+    return
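Unlike the forward registrations, the two grad registrations drop `op_pattern("formatAgnostic")` and instead enumerate (y, dy, z) dtype/format triples explicitly across the 5HD, FracZ, C1HWNCoC0 and default layouts. The reason AsinhGrad can consume the forward output `y` at all is the identity cosh(asinh(x)) = sqrt(1 + x^2); a one-line NumPy check of that identity:

import numpy as np

# cosh(asinh(x)) == sqrt(1 + x^2), so dy / cosh(y) equals dy / sqrt(1 + x^2).
x = np.linspace(-3.0, 3.0, 7)
assert np.allclose(np.cosh(np.arcsinh(x)), np.sqrt(1.0 + x * x))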

mindspore/ops/operations/__init__.py (+4, -1)

@@ -39,7 +39,8 @@ from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print)
 from .control_ops import ControlDepend, GeSwitch, Merge
 from .inner_ops import ScalarCast
-from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, BitwiseXor,
+from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr,
+                       BitwiseXor,
                        ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd,
                        Cos, Div, Equal, EqualCount, Exp, Erf, Erfc, Floor, FloorDiv, FloorMod, Acosh,
                        Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd,
@@ -239,6 +240,7 @@ __all__ = [
     'FloorDiv',
     'FloorMod',
     'Acosh',
+    'Asinh',
     "PReLU",
     "Cos",
     "ACos",
@@ -249,6 +251,7 @@ __all__ = [
     'AssignAdd',
     'AssignSub',
     "Sin",
+    "Asin",
     "LSTM",
     "Abs",
     "BinaryCrossEntropy",


mindspore/ops/operations/_grad_ops.py (+39, -0)

@@ -76,6 +76,45 @@ class AcoshGrad(PrimitiveWithInfer):
         return x
 
 
+class AsinGrad(PrimitiveWithInfer):
+    """
+    Computes AsinGrad of input element-wise.
+
+    Returns:
+        Tensor, has the same type as input.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Init AsinGrad"""
+
+    def infer_shape(self, x, dout):
+        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
+        return x
+
+    def infer_dtype(self, x, dout):
+        args = {"x": x, "dout": dout}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        return x
+
+
+class AsinhGrad(PrimitiveWithInfer):
+    """Performs grad of Asinh operation."""
+
+    @prim_attr_register
+    def __init__(self):
+        """init AsinhGrad"""
+
+    def infer_shape(self, x, dout):
+        validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name)
+        return x
+
+    def infer_dtype(self, x, dout):
+        args = {"x": x, "dout": dout}
+        validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        return x
+
+
 class BatchNormGrad(PrimitiveWithInfer):
     """Performs grad of BatchNorm operation."""



mindspore/ops/operations/math_ops.py (+80, -22)

@@ -1336,8 +1336,7 @@ class Acosh(PrimitiveWithInfer):
     Compute inverse hyperbolic cosine of x element-wise.
 
     Inputs:
-        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`,
-          and the data type of 'input_x' is number, the element in 'input_x' should be greater than or equal to 1.
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
 
     Outputs:
         Tensor, has the same shape as `input_x`.
@@ -1352,12 +1351,42 @@ class Acosh(PrimitiveWithInfer):
     def __init__(self):
         """init Acosh"""
 
-    def infer_shape(self, x):
-        return x
+    def infer_shape(self, x_shape):
+        return x_shape
 
-    def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
-        return x
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        return x_dtype
+
+
+class Asinh(PrimitiveWithInfer):
+    """
+    Compute inverse hyperbolic sine of x element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Outputs:
+        Tensor, has the same shape as `input_x`.
+
+    Examples:
+        >>> asinh = P.Asinh()
+        >>> input_x = Tensor(np.array([-5.0, 1.5, 3.0, 100.0]), mindspore.float32)
+        >>> output = asinh(input_x)
+        [-2.3124, 1.1948, 1.8184, 5.2983]
+    """
+
+
+    @prim_attr_register
+    def __init__(self):
+        """init Asinh"""
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        return x_dtype
 
 
 class _LogicBinaryOp(_BinaryOp):
@@ -1926,12 +1955,12 @@ class Cos(PrimitiveWithInfer):
     def __init__(self):
         """init Cos"""
 
-    def infer_shape(self, x):
-        return x
+    def infer_shape(self, x_shape):
+        return x_shape
 
-    def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
-        return x
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        return x_dtype
 
 
 class ACos(PrimitiveWithInfer):
@@ -1954,12 +1983,12 @@ class ACos(PrimitiveWithInfer):
     def __init__(self):
         """init ACos"""
 
-    def infer_shape(self, x):
-        return x
+    def infer_shape(self, x_shape):
+        return x_shape
 
-    def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
-        return x
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        return x_dtype
 
 
 class Sin(PrimitiveWithInfer):
@@ -1982,12 +2011,41 @@ class Sin(PrimitiveWithInfer):
     def __init__(self):
         """Init Sin."""
 
-    def infer_shape(self, x):
-        return x
+    def infer_shape(self, x_shape):
+        return x_shape
 
-    def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
-        return x
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        return x_dtype
+
+
+class Asin(PrimitiveWithInfer):
+    """
+    Computes arcsine of input element-wise.
+
+    Inputs:
+        - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
+
+    Outputs:
+        Tensor, has the same shape as `input_x`.
+
+    Examples:
+        >>> asin = P.Asin()
+        >>> input_x = Tensor(np.array([0.74, 0.04, 0.30, 0.56]), mindspore.float32)
+        >>> output = asin(input_x)
+        [0.8331, 0.0400, 0.3047, 0.5944]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init Asin"""
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
+        return x_dtype
 
 
 class NMSWithMask(PrimitiveWithInfer):
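A quick cross-check of the example values quoted in the two new docstrings, using NumPy's equivalents of the same element-wise functions (run outside MindSpore):

import numpy as np

print(np.round(np.arcsin(np.array([0.74, 0.04, 0.30, 0.56], np.float32)), 4))
# -> approx. [0.8331, 0.04, 0.3047, 0.5944]
print(np.round(np.arcsinh(np.array([-5.0, 1.5, 3.0, 100.0], np.float32)), 4))
# -> approx. [-2.3124, 1.1948, 1.8184, 5.2983]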


tests/ut/python/ops/test_ops.py (+8, -0)

@@ -369,6 +369,14 @@ test_case_math_ops = [
         'block': P.Sin(),
         'desc_inputs': [[2, 3]],
         'desc_bprop': [[2, 3]]}),
+    ('Asin', {
+        'block': P.Asin(),
+        'desc_inputs': [[2, 3]],
+        'desc_bprop': [[2, 3]]}),
+    ('Asinh', {
+        'block': P.Asinh(),
+        'desc_inputs': [[3, 4, 5]],
+        'desc_bprop': [[3, 4, 5]]}),
     ('Reciprocal', {
         'block': P.Reciprocal(),
         'desc_inputs': [[2, 3, 3, 5]],
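The new entries reuse the standard harness: `desc_inputs` gives the forward input shapes and `desc_bprop` the gradient shapes fed back through the registered bprops. As an independent sanity check of the derivative formulas those bprops rely on, a finite-difference sketch (plain NumPy, outside the test suite):

import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-0.9, 0.9, size=(2, 3))  # stay inside asin's domain (-1, 1)
eps = 1e-6

# Central differences should match 1/sqrt(1 - x^2) and 1/sqrt(1 + x^2).
asin_num = (np.arcsin(x + eps) - np.arcsin(x - eps)) / (2 * eps)
assert np.allclose(asin_num, 1.0 / np.sqrt(1.0 - x * x), atol=1e-4)

asinh_num = (np.arcsinh(x + eps) - np.arcsinh(x - eps)) / (2 * eps)
assert np.allclose(asinh_num, 1.0 / np.sqrt(1.0 + x * x), atol=1e-4)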

