diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index fc92bfc8fb..34547fd755 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1712,6 +1712,7 @@ class FloatStatus(PrimitiveWithInfer): return [1] def infer_dtype(self, x_dtype): + validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32, mstype.float16], self.name) return x_dtype class NPUAllocFloatStatus(PrimitiveWithInfer):