|
|
|
@@ -185,6 +185,13 @@ class Cast(PrimitiveWithInfer): |
|
|
|
"""init Cast""" |
|
|
|
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) |
|
|
|
|
|
|
|
def check_elim(self, x, dtype): |
|
|
|
if isinstance(x, Tensor): |
|
|
|
if x.dtype() == dtype: |
|
|
|
return (True, x) |
|
|
|
return (False, None) |
|
|
|
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) |
|
|
|
|
|
|
|
def __infer__(self, x, t): |
|
|
|
src_type = x['dtype'] |
|
|
|
dst_type = t['value'] |
|
|
|
@@ -1345,6 +1352,15 @@ class Tile(PrimitiveWithInfer): |
|
|
|
"""init Tile""" |
|
|
|
self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output']) |
|
|
|
|
|
|
|
def check_elim(self, base_tensor, multiplier): |
|
|
|
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): |
|
|
|
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) |
|
|
|
def is_all_zeros(v_tuple): |
|
|
|
return all(v == 1 for v in v_tuple) |
|
|
|
if is_all_zeros(multiplier): |
|
|
|
return (True, base_tensor) |
|
|
|
return (False, None) |
|
|
|
|
|
|
|
def __infer__(self, x, multiples): |
|
|
|
multiples_v = multiples['value'] |
|
|
|
x_shp = x['shape'] |
|
|
|
|