|
|
@@ -1656,6 +1656,8 @@ class IsFinite(PrimitiveWithInfer): |
|
|
return x_shape |
|
|
return x_shape |
|
|
|
|
|
|
|
|
def infer_dtype(self, x_dtype): |
|
|
def infer_dtype(self, x_dtype): |
|
|
|
|
|
validator.check_subclass("x", x_dtype, mstype.tensor, self.name) |
|
|
|
|
|
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type + (mstype.bool_,), self.name) |
|
|
return mstype.bool_ |
|
|
return mstype.bool_ |
|
|
|
|
|
|
|
|
class FloatStatus(PrimitiveWithInfer): |
|
|
class FloatStatus(PrimitiveWithInfer): |
|
|
|