Merge pull request !1926 from BowenK/op_elimtags/v0.5.0-beta
| @@ -185,6 +185,13 @@ class Cast(PrimitiveWithInfer): | |||||
| """init Cast""" | """init Cast""" | ||||
| self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) | ||||
| def check_elim(self, x, dtype): | |||||
| if isinstance(x, Tensor): | |||||
| if x.dtype() == dtype: | |||||
| return (True, x) | |||||
| return (False, None) | |||||
| raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs)) | |||||
| def __infer__(self, x, t): | def __infer__(self, x, t): | ||||
| src_type = x['dtype'] | src_type = x['dtype'] | ||||
| dst_type = t['value'] | dst_type = t['value'] | ||||
| @@ -1345,6 +1352,15 @@ class Tile(PrimitiveWithInfer): | |||||
| """init Tile""" | """init Tile""" | ||||
| self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output']) | ||||
| def check_elim(self, base_tensor, multiplier): | |||||
| if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)): | |||||
| raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier)) | |||||
| def is_all_zeros(v_tuple): | |||||
| return all(v == 1 for v in v_tuple) | |||||
| if is_all_zeros(multiplier): | |||||
| return (True, base_tensor) | |||||
| return (False, None) | |||||
| def __infer__(self, x, multiples): | def __infer__(self, x, multiples): | ||||
| multiples_v = multiples['value'] | multiples_v = multiples['value'] | ||||
| x_shp = x['shape'] | x_shp = x['shape'] | ||||
| @@ -715,6 +715,13 @@ class AddN(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) | ||||
| def check_elim(self, inputs): | |||||
| if len(inputs) != 1: | |||||
| return (False, None) | |||||
| if isinstance(inputs[0], Tensor): | |||||
| return (True, inputs[0]) | |||||
| raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0]))) | |||||
| def infer_shape(self, inputs): | def infer_shape(self, inputs): | ||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) | ||||
| @@ -140,9 +140,24 @@ class Primitive(Primitive_): | |||||
| return self.attrs[item] | return self.attrs[item] | ||||
| raise AttributeError(item) | raise AttributeError(item) | ||||
| def check_elim(self, *args): | |||||
| """ | |||||
| Check whether or not certain inputs should go into backend. Subclass in need should override this method. | |||||
| Args: | |||||
| Same as arguments of current Primitive | |||||
| Returns: | |||||
| A tuple of two elements, first element indicates whether or not we should filter out current arguments; | |||||
| seconde element is the output in case where we should filter out the arguments. | |||||
| """ | |||||
| return (False, None) | |||||
| def __call__(self, *args): | def __call__(self, *args): | ||||
| output = _run_op(self, self.name, args) | |||||
| return output | |||||
| should_elim, output = self.check_elim(*args) | |||||
| if should_elim: | |||||
| return output | |||||
| return _run_op(self, self.name, args) | |||||
| def __getstate__(self): | def __getstate__(self): | ||||
| return self.__dict__ | return self.__dict__ | ||||