Browse Source

!1926 Skip operations which are not supported by the backend in PyNative Mode

Merge pull request !1926 from BowenK/op_elim
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
cedfc7fac0
3 changed files with 40 additions and 2 deletions
  1. +16
    -0
      mindspore/ops/operations/array_ops.py
  2. +7
    -0
      mindspore/ops/operations/math_ops.py
  3. +17
    -2
      mindspore/ops/primitive.py

+ 16
- 0
mindspore/ops/operations/array_ops.py View File

@@ -185,6 +185,13 @@ class Cast(PrimitiveWithInfer):
"""init Cast""" """init Cast"""
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])


def check_elim(self, x, dtype):
if isinstance(x, Tensor):
if x.dtype() == dtype:
return (True, x)
return (False, None)
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))

def __infer__(self, x, t): def __infer__(self, x, t):
src_type = x['dtype'] src_type = x['dtype']
dst_type = t['value'] dst_type = t['value']
@@ -1345,6 +1352,15 @@ class Tile(PrimitiveWithInfer):
"""init Tile""" """init Tile"""
self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output'])


def check_elim(self, base_tensor, multiplier):
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
def is_all_zeros(v_tuple):
return all(v == 1 for v in v_tuple)
if is_all_zeros(multiplier):
return (True, base_tensor)
return (False, None)

def __infer__(self, x, multiples): def __infer__(self, x, multiples):
multiples_v = multiples['value'] multiples_v = multiples['value']
x_shp = x['shape'] x_shp = x['shape']


+ 7
- 0
mindspore/ops/operations/math_ops.py View File

@@ -715,6 +715,13 @@ class AddN(PrimitiveWithInfer):
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])


def check_elim(self, inputs):
if len(inputs) != 1:
return (False, None)
if isinstance(inputs[0], Tensor):
return (True, inputs[0])
raise TypeError("Expecting Tensor, got : {}".format(type(inputs[0])))

def infer_shape(self, inputs): def infer_shape(self, inputs):
cls_name = self.name cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)


+ 17
- 2
mindspore/ops/primitive.py View File

@@ -140,9 +140,24 @@ class Primitive(Primitive_):
return self.attrs[item] return self.attrs[item]
raise AttributeError(item) raise AttributeError(item)


def check_elim(self, *args):
"""
Check whether or not certain inputs should go into backend. Subclass in need should override this method.

Args:
Same as arguments of current Primitive

Returns:
A tuple of two elements, first element indicates whether or not we should filter out current arguments;
seconde element is the output in case where we should filter out the arguments.
"""
return (False, None)

def __call__(self, *args): def __call__(self, *args):
output = _run_op(self, self.name, args)
return output
should_elim, output = self.check_elim(*args)
if should_elim:
return output
return _run_op(self, self.name, args)


def __getstate__(self): def __getstate__(self):
return self.__dict__ return self.__dict__


Loading…
Cancel
Save