From 396c18b92b35188e1ca8e127e72687bdd7d27cf4 Mon Sep 17 00:00:00 2001 From: kingfo Date: Fri, 19 Jun 2020 18:32:48 +0800 Subject: [PATCH] fix parameter cast issue --- mindspore/ops/operations/array_ops.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2bb8a17a50..0b44e83b2f 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -28,6 +28,7 @@ from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor +from ...common.parameter import Parameter from ..operations.math_ops import _infer_shape_reduce from .._utils import get_concat_offset from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op @@ -186,11 +187,15 @@ class Cast(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) def check_elim(self, x, dtype): - if isinstance(x, (Tensor, numbers.Number)): + if isinstance(x, (Tensor, numbers.Number, Parameter)): if isinstance(x, Tensor) and x.dtype == dtype: return (True, x) if isinstance(x, numbers.Number): return (True, Tensor(x, dtype=dtype)) + if isinstance(x, Parameter): + data = x.default_input + if data.dtype == dtype: + return (True, x) return (False, None) raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")