
!13772 fix the bug that SeLU accepted data types other than the supported float16 and float32, and update the documentation of the Mish operator.

From: @wangshuide2020
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot committed 5 years ago
commit 0f833b595a
1 changed file with 12 additions and 4 deletions:
  mindspore/ops/operations/nn_ops.py (+12, -4)

mindspore/ops/operations/nn_ops.py

@@ -372,7 +372,7 @@ class ReLU(PrimitiveWithCheck):

class Mish(PrimitiveWithInfer):
r"""
- Computes MISH of input tensors element-wise.
+ Computes MISH (A Self Regularized Non-Monotonic Neural Activation Function) of input tensors element-wise.

The function is shown as follows:

@@ -380,6 +380,9 @@ class Mish(PrimitiveWithInfer):

\text{output} = x * \tanh(\log(1 + \exp(x)))

+ See more details in `A Self Regularized Non-Monotonic Neural Activation Function
+ <https://arxiv.org/abs/1908.08681>`_.

Inputs:
- **x** (Tensor) - The input tensor. Only support float16 and float32.

@@ -390,7 +393,7 @@ class Mish(PrimitiveWithInfer):
``Ascend``

Raise:
- TypeError: If num_features data type not float16 and float32 Tensor.
+ TypeError: If dtype of `x` is neither float16 nor float32.

Examples:
>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
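The Mish formula added above, output = x * tanh(log(1 + exp(x))), can be checked with a minimal NumPy sketch. This only illustrates the math on the docstring's sample input; it is not the MindSpore Ascend kernel:

```python
import numpy as np

def mish(x):
    # Mish(x) = x * tanh(softplus(x)), with softplus(x) = log(1 + exp(x)).
    # log1p keeps softplus accurate when exp(x) is small.
    return x * np.tanh(np.log1p(np.exp(x)))

# Same sample values as the docstring example, as a plain NumPy array.
input_x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
print(mish(input_x))
```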
@@ -428,6 +431,11 @@ class SeLU(PrimitiveWithInfer):
\text{alpha} * (\exp(x_i) - 1), &\text{otherwise.}
\end{cases}

+ where :math:`alpha` and :math:`scale` are pre-defined constants (:math:`alpha=1.67326324`
+ and :math:`scale=1.05070098`).
+
+ See more details in `Self-Normalizing Neural Networks <https://arxiv.org/abs/1706.02515>`_.

Inputs:
- **input_x** (Tensor) - The input tensor.

@@ -438,7 +446,7 @@ class SeLU(PrimitiveWithInfer):
``Ascend``

Raise:
- TypeError: If num_features data type not int8, int32, float16 and float32 Tensor.
+ TypeError: If dtype of `input_x` is neither float16 nor float32.

Examples:
>>> input_x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
@@ -458,7 +466,7 @@ class SeLU(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype):
- valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
+ valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid('x', x_dtype, valid_dtypes, self.name)
return x_dtype
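The SeLU definition and the tightened dtype check above can be sketched in plain NumPy. This is illustrative only: the constant names and the error-message wording are assumptions, not the MindSpore validator's output.

```python
import numpy as np

SCALE = 1.05070098   # pre-defined `scale` constant from the docstring
ALPHA = 1.67326324   # pre-defined `alpha` constant from the docstring

def selu(input_x):
    # Mirror the commit's restriction: only float16 and float32 are accepted.
    if input_x.dtype not in (np.float16, np.float32):
        raise TypeError("For SeLU, dtype of `input_x` should be float16 or float32.")
    # scale * x                       if x >= 0
    # scale * alpha * (exp(x) - 1)    otherwise; expm1 avoids cancellation near 0
    return SCALE * np.where(input_x >= 0, input_x, ALPHA * np.expm1(input_x))

# Same sample values as the docstring example, as a plain NumPy array.
input_x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]], dtype=np.float32)
print(selu(input_x))
```

An int8 or int32 array now raises TypeError here, matching the narrowed `valid_dtypes` list in `infer_dtype`.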


