|
|
|
@@ -186,11 +186,13 @@ class Cast(PrimitiveWithInfer): |
|
|
|
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) |
|
|
|
|
|
|
|
def check_elim(self, x, dtype): |
|
|
|
if isinstance(x, Tensor): |
|
|
|
if x.dtype == dtype: |
|
|
|
if isinstance(x, (Tensor, numbers.Number)): |
|
|
|
if isinstance(x, Tensor) and x.dtype == dtype: |
|
|
|
return (True, x) |
|
|
|
if isinstance(x, numbers.Number): |
|
|
|
return (True, Tensor(x, dtype=dtype)) |
|
|
|
return (False, None) |
|
|
|
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) |
|
|
|
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})") |
|
|
|
|
|
|
|
def __infer__(self, x, t): |
|
|
|
src_type = x['dtype'] |
|
|
|
|