|
|
|
@@ -311,14 +311,16 @@ class Cast(PrimitiveWithInfer): |
|
|
|
|
|
|
|
def check_elim(self, x, dtype): |
|
|
|
if isinstance(x, (Tensor, numbers.Number, Parameter)): |
|
|
|
if isinstance(x, Tensor) and x.dtype == dtype: |
|
|
|
return (True, x) |
|
|
|
if isinstance(x, numbers.Number): |
|
|
|
return (True, Tensor(x, dtype=dtype)) |
|
|
|
if isinstance(x, Parameter): |
|
|
|
data = x.data |
|
|
|
if data.dtype == dtype: |
|
|
|
return (True, x) |
|
|
|
if isinstance(x, Tensor) and x.dtype == dtype: |
|
|
|
x = Tensor(x) |
|
|
|
x.set_cast_dtype() |
|
|
|
return (True, x) |
|
|
|
if isinstance(x, numbers.Number): |
|
|
|
return (True, Tensor(x, dtype=dtype)) |
|
|
|
return (False, None) |
|
|
|
|
|
|
|
def __infer__(self, x, t): |
|
|
|
|