|
|
|
@@ -28,6 +28,7 @@ from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import Rel |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ...common.tensor import Tensor |
|
|
|
from ...common.parameter import Parameter |
|
|
|
from ..operations.math_ops import _infer_shape_reduce |
|
|
|
from .._utils import get_concat_offset |
|
|
|
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op |
|
|
|
@@ -186,11 +187,15 @@ class Cast(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) |
|
|
|
|
|
|
|
def check_elim(self, x, dtype): |
|
|
|
if isinstance(x, (Tensor, numbers.Number)): |
|
|
|
if isinstance(x, (Tensor, numbers.Number, Parameter)): |
|
|
|
if isinstance(x, Tensor) and x.dtype == dtype: |
|
|
|
return (True, x) |
|
|
|
if isinstance(x, numbers.Number): |
|
|
|
return (True, Tensor(x, dtype=dtype)) |
|
|
|
if isinstance(x, Parameter): |
|
|
|
data = x.default_input |
|
|
|
if data.dtype == dtype: |
|
|
|
return (True, x) |
|
|
|
return (False, None) |
|
|
|
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") |
|
|
|
|
|
|
|
|