Browse Source

!1804 add tensor compare & len & constexpr feature

Merge pull request !1804 from wangqiuliang/add-tensor-compare-len-consexpr-feature
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
5d397d8404
5 changed files with 51 additions and 6 deletions
  1. +5
    -0
      mindspore/ccsrc/pynative/pynative_execute.cc
  2. +28
    -6
      mindspore/common/tensor.py
  3. +8
    -0
      mindspore/ops/functional.py
  4. +1
    -0
      mindspore/ops/primitive.py
  5. +9
    -0
      tests/ut/python/pynative_mode/test_parse_method.py

+ 5
- 0
mindspore/ccsrc/pynative/pynative_execute.cc View File

@@ -531,6 +531,11 @@ py::tuple RunOp(const py::args &args) {
value_ret[0] = output["value"];
return value_ret;
}
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
py::tuple value_ret(1);
value_ret[0] = "";
return value_ret;
}
}
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name;
mindspore::parse::python_adapter::set_python_env_flag(true);


+ 28
- 6
mindspore/common/tensor.py View File

@@ -71,19 +71,18 @@ class Tensor(Tensor_):
return str(self.__str__())

def __add__(self, other):
check_type('tensor input_data', other, (Tensor, float, int))
out = tensor_operator_registry.get('__add__')(self, other)
return out

def __eq__(self, other):
if not isinstance(other, Tensor):
return False
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
return tensor_operator_registry.get('__eq__')(self, other)

def __ne__(self, other):
if not isinstance(other, Tensor):
return True
return Tensor(np.array(self.asnumpy() != other.asnumpy()))
return tensor_operator_registry.get('__ne__')(self, other)

def __hash__(self):
return hash(id(self))
@@ -93,7 +92,8 @@ class Tensor(Tensor_):
return out

def __neg__(self):
return Tensor(-self.asnumpy())
out = tensor_operator_registry.get('__neg__')(self)
return out

def __iadd__(self, other):
out = self.__add__(other)
@@ -120,7 +120,7 @@ class Tensor(Tensor_):
return out

def __sub__(self, other):
out = self.__add__(-other)
out = tensor_operator_registry.get('__sub__')(self, other)
return out

def __isub__(self, other):
@@ -128,9 +128,31 @@ class Tensor(Tensor_):
return out

def __rsub__(self, other):
out = tensor_operator_registry.get('__add__')(other, Tensor(-self.asnumpy()))
out = tensor_operator_registry.get('__sub__')(other, self)
return out

def __lt__(self, other):
out = tensor_operator_registry.get('__lt__')(self, other)
return out

def __le__(self, other):
out = tensor_operator_registry.get('__le__')(self, other)
return out

def __gt__(self, other):
out = tensor_operator_registry.get('__gt__')(self, other)
return out

def __ge__(self, other):
out = tensor_operator_registry.get('__ge__')(self, other)
return out

def __len__(self):
out = tensor_operator_registry.get('__shape__')(self)
if not out:
return 1
return out[0]

def __str__(self):
if self.dtype() == mstype.type_none:
return "Unknown Tensor type!"


+ 8
- 0
mindspore/ops/functional.py View File

@@ -151,7 +151,15 @@ shape_mul = Primitive("shape_mul")
stop_gradient = Primitive("stop_gradient")

tensor_operator_registry.register('__add__', tensor_add)
tensor_operator_registry.register('__sub__', tensor_sub)
tensor_operator_registry.register('__mul__', tensor_mul)
tensor_operator_registry.register('__div__', tensor_div)
#ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
tensor_operator_registry.register('__neg__', neg_tensor)
tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__shape__', shape)

+ 1
- 0
mindspore/ops/primitive.py View File

@@ -310,6 +310,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def __init__(self):
op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name)
self.const_value = True

def infer_value(self, *args):
return fn(*args)


+ 9
- 0
tests/ut/python/pynative_mode/test_parse_method.py View File

@@ -29,6 +29,7 @@ from mindspore._extends.parse.standard_method import ms_len
from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
from mindspore.ops.composite import core
from mindspore.ops.primitive import constexpr
from ..ut_filter import non_graph_engine


@@ -417,3 +418,11 @@ def test_range():
""" test_range """
res = range_spec(10, 10)
return res

def test_expr():
""" test const expr """
a = (1, 2)
@constexpr
def tuple_len(x):
assert len(x) == 2
tuple_len(a)

Loading…
Cancel
Save