| @@ -28,6 +28,7 @@ from ..._checkparam import Validator as validator | |||||
| from ..._checkparam import Rel | from ..._checkparam import Rel | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| from ...common.parameter import Parameter | |||||
| from ..operations.math_ops import _infer_shape_reduce | from ..operations.math_ops import _infer_shape_reduce | ||||
| from .._utils import get_concat_offset | from .._utils import get_concat_offset | ||||
| from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op | from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op | ||||
| @@ -186,11 +187,15 @@ class Cast(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) | ||||
| def check_elim(self, x, dtype): | def check_elim(self, x, dtype): | ||||
| if isinstance(x, (Tensor, numbers.Number)): | |||||
| if isinstance(x, (Tensor, numbers.Number, Parameter)): | |||||
| if isinstance(x, Tensor) and x.dtype == dtype: | if isinstance(x, Tensor) and x.dtype == dtype: | ||||
| return (True, x) | return (True, x) | ||||
| if isinstance(x, numbers.Number): | if isinstance(x, numbers.Number): | ||||
| return (True, Tensor(x, dtype=dtype)) | return (True, Tensor(x, dtype=dtype)) | ||||
| if isinstance(x, Parameter): | |||||
| data = x.default_input | |||||
| if data.dtype == dtype: | |||||
| return (True, x) | |||||
| return (False, None) | return (False, None) | ||||
| raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") | raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") | ||||