
fix the dtype validation of the Softmax, Tanh and Elu operators.

tags/v1.2.0-rc1
wangshuide2020 committed 4 years ago
commit 8da6d65222
2 changed files with 5 additions and 5 deletions
  1. mindspore/nn/layer/activation.py (+1, -1)
  2. mindspore/ops/operations/nn_ops.py (+4, -4)

mindspore/nn/layer/activation.py (+1, -1)

@@ -175,7 +175,7 @@ class ELU(Cell):
         ValueError: If `alpha` is not equal to 1.0.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
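The activation.py hunk is documentation-only: it advertises CPU support for nn.ELU. A minimal usage sketch (not part of the commit, assuming a CPU build of this MindSpore version):

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor, context

    # Run on the newly documented CPU backend.
    context.set_context(device_target="CPU")
    elu = nn.ELU()  # alpha defaults to 1.0, the only supported value
    x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
    print(elu(x))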


mindspore/ops/operations/nn_ops.py (+4, -4)

@@ -172,7 +172,7 @@ class Softmax(PrimitiveWithInfer):
         return logits
 
     def infer_dtype(self, logits):
-        validator.check_tensor_dtype_valid("logits", logits, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name)
         return logits
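For context, a sketch (not part of the diff) of what the widened whitelist amounts to, assuming mstype.float_type in this MindSpore version is the tuple (float16, float32, float64):

    from mindspore.common import dtype as mstype

    # The old check only allowed these two dtypes.
    old_valid = (mstype.float16, mstype.float32)
    # The new check accepts every float dtype; the practical addition is float64.
    print(mstype.float_type)                    # assumed: (Float16, Float32, Float64)
    print(mstype.float64 in mstype.float_type)  # True under that assumption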


@@ -603,7 +603,7 @@ class Elu(PrimitiveWithInfer):
         return input_x
 
     def infer_dtype(self, input_x):
-        validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name)
         return input_x


@@ -761,7 +761,7 @@ class Tanh(PrimitiveWithInfer):
         TypeError: If dtype of `input_x` is neither float16 nor float32.
 
     Supported Platforms:
-        ``Ascend`` ``GPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32)
@@ -779,7 +779,7 @@ class Tanh(PrimitiveWithInfer):
         return input_x
 
     def infer_dtype(self, input_x):
-        validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name)
+        validator.check_tensor_dtype_valid("input_x", input_x, mstype.float_type, self.name)
         return input_x
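A hedged usage sketch of the behavior this enables: float64 inputs should now pass dtype validation for all three operators (whether the kernel itself runs in float64 still depends on the target backend):

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.ops import operations as P

    x = Tensor(np.array([-1.0, 0.0, 1.0]), mindspore.float64)
    for op in (P.Softmax(), P.Elu(), P.Tanh()):
        # Before this commit each call raised TypeError during dtype validation;
        # after it, validation accepts any dtype in mstype.float_type.
        out = op(x)
        print(op.name, out.dtype)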

