|
|
|
@@ -172,7 +172,7 @@ class Softmax(PrimitiveWithInfer): |
|
|
|
return logits |
|
|
|
|
|
|
|
def infer_dtype(self, logits): |
|
|
|
validator.check_tensor_dtype_valid("logits", logits, (mstype.float16, mstype.float32), self.name) |
|
|
|
validator.check_tensor_dtype_valid("logits", logits, mstype.float_type, self.name) |
|
|
|
return logits |
|
|
|
|
|
|
|
|
|
|
|
@@ -603,7 +603,7 @@ class Elu(PrimitiveWithInfer): |
|
|
|
return input_x |
|
|
|
|
|
|
|
def infer_dtype(self, input_x): |
|
|
|
validator.check_tensor_dtype_valid('input_x', input_x, (mstype.float16, mstype.float32), self.name) |
|
|
|
validator.check_tensor_dtype_valid('input_x', input_x, mstype.float_type, self.name) |
|
|
|
return input_x |
|
|
|
|
|
|
|
|
|
|
|
@@ -761,7 +761,7 @@ class Tanh(PrimitiveWithInfer): |
|
|
|
TypeError: If dtype of `input_x` is neither float16 nor float32. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
``Ascend`` ``GPU`` ``CPU`` |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32) |
|
|
|
@@ -779,7 +779,7 @@ class Tanh(PrimitiveWithInfer): |
|
|
|
return input_x |
|
|
|
|
|
|
|
def infer_dtype(self, input_x): |
|
|
|
validator.check_tensor_dtype_valid("input_x", input_x, (mstype.float16, mstype.float32), self.name) |
|
|
|
validator.check_tensor_dtype_valid("input_x", input_x, mstype.float_type, self.name) |
|
|
|
return input_x |
|
|
|
|
|
|
|
|
|
|
|
|