diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 8de4108435..6eebde3a84 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1656,6 +1656,8 @@ class IsFinite(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): + validator.check_subclass("x", x_dtype, mstype.tensor, self.name) + validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) return mstype.bool_ class FloatStatus(PrimitiveWithInfer):