diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 1b1366d92e..da8f74930e 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -320,16 +320,16 @@ class Cast(PrimitiveWithInfer): def check_elim(self, x, dtype): if isinstance(x, (Tensor, numbers.Number, Parameter)): + if isinstance(x, Parameter): + data = x.data + if data.dtype == dtype: + return (True, x) if isinstance(x, Tensor) and x.dtype == dtype: x = Tensor(x) x.set_cast_dtype() return (True, x) if isinstance(x, numbers.Number): return (True, Tensor(x, dtype=dtype)) - if isinstance(x, Parameter): - data = x.data - if data.dtype == dtype: - return (True, x) return (False, None) def __infer__(self, x, t):