|
|
|
@@ -308,7 +308,8 @@ class Elu(PrimitiveWithInfer): |
|
|
|
The data type of input tensor should be float. |
|
|
|
|
|
|
|
Args: |
|
|
|
alpha (float): The coefficient of negative factor whose type is float. Default: 1.0. |
|
|
|
alpha (float): The coefficient of negative factor whose type is float, |
|
|
|
only support '1.0' currently. Default: 1.0. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input_x** (Tensor) - The input tensor whose data type should be float. |
|
|
|
@@ -328,6 +329,7 @@ class Elu(PrimitiveWithInfer): |
|
|
|
def __init__(self, alpha=1.0): |
|
|
|
"""Init Elu""" |
|
|
|
validator.check_value_type("alpha", alpha, [float], self.name) |
|
|
|
validator.check_number("alpha", alpha, 1.0, Rel.EQ, self.name) |
|
|
|
|
|
|
|
def infer_shape(self, input_x): |
|
|
|
return input_x |
|
|
|
|