Browse Source

!7438 Support judge tensor in(or not in) a list(or tuple)

Merge pull request !7438 from huanghui/tensor-in-list
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7f343e404a
6 changed files with 152 additions and 5 deletions
  1. +9
    -0
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  2. +11
    -2
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  3. +31
    -0
      mindspore/ops/composite/multitype_ops/in_impl.py
  4. +31
    -0
      mindspore/ops/composite/multitype_ops/not_in_impl.py
  5. +34
    -1
      tests/ut/python/dtype/test_list.py
  6. +36
    -2
      tests/ut/python/dtype/test_tuple.py

+ 9
- 0
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -577,3 +577,12 @@ def tensor_setitem_by_ellipsis_with_tensor(data, index, value):
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result


def tensor_in_sequence(x, y):
"""Assigns whether a sequence contains the given tensor"""
for i in y:
if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype:
if F.equal(x, i).all():
return const_utils.scalar_to_tensor(True)
return const_utils.scalar_to_tensor(False)

+ 11
- 2
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -39,14 +39,17 @@ TENSOR_GETITEM = "tensor getitem"
SET_ITEM_BY_ONE_TENSOR = 0
SET_ITEM_BY_TUPLE_OF_TENSOR = 1


@constexpr
def raise_value_error(msg):
raise ValueError(msg)


@constexpr
def raise_index_error(msg):
raise IndexError(msg)


@constexpr
def raise_type_error(msg):
raise TypeError(msg)
@@ -704,7 +707,7 @@ def get_stride_info_from_slice(data_shape, slice_index):
def get_stride_info_from_integer(data_shape, number):
"""Get stride info from a integer"""
begin_strides = [number]
end_strides = [number+1]
end_strides = [number + 1]
step_strides = [1]
for end in data_shape[1:]:
begin_strides.append(0)
@@ -720,7 +723,7 @@ def get_slice_stride(dim_size, index_slice):
stop_default = dim_size
if step < 0:
start_default = -1
stop_default = -(dim_size+1)
stop_default = -(dim_size + 1)
start = start_default if index_slice.start is None else index_slice.start
stop = stop_default if index_slice.stop is None else index_slice.stop
return start, stop, step
@@ -775,3 +778,9 @@ def mstype_eq(x, y):
if x == y:
return True
return False


@constexpr
def scalar_to_tensor(x):
"""Convert a scalar to a tensor"""
return Tensor(x)

+ 31
- 0
mindspore/ops/composite/multitype_ops/in_impl.py View File

@@ -16,6 +16,7 @@
"""Implementation for internal polymorphism `in` operations."""

from . import _constexpr_utils as const_utils
from . import _compile_utils as compile_utils
from ... import functional as F
from ...composite import base

@@ -99,3 +100,33 @@ def _str_in_dict(x, y):
bool, if x in y return true, x not in y return false.
"""
return F.in_dict(x, y)


@in_.register("Tensor", "List")
def _tensor_in_list(x, y):
"""
Determine if a tensor in a list.

Args:
x: Tensor
y: List

Returns:
bool, if x in y return true, x not in y return false.
"""
return compile_utils.tensor_in_sequence(x, y)


@in_.register("Tensor", "Tuple")
def _tensor_in_tuple(x, y):
"""
Determine if a tensor in a tuple.

Args:
x: Tensor
y: Tuple

Returns:
bool, if x in y return true, x not in y return false.
"""
return compile_utils.tensor_in_sequence(x, y)

+ 31
- 0
mindspore/ops/composite/multitype_ops/not_in_impl.py View File

@@ -16,6 +16,7 @@
"""Implementation for internal polymorphism `not in` operations."""

from . import _constexpr_utils as const_utils
from . import _compile_utils as compile_utils
from ... import functional as F
from ...composite import base

@@ -99,3 +100,33 @@ def _str_not_in_dict(x, y):
bool, if x not in y return true, x in y return false.
"""
return F.not_in_dict(x, y)


@not_in_.register("Tensor", "List")
def _tensor_not_in_list(x, y):
"""
Determine if a tensor not in a list.

Args:
x: Tensor
y: List

Returns:
bool, if x not in y return true, x in y return false.
"""
return not compile_utils.tensor_in_sequence(x, y)


@not_in_.register("Tensor", "Tuple")
def _tensor_not_in_tuple(x, y):
"""
Determine if a tensor not in a tuple.

Args:
x: Tensor
y: Tuple

Returns:
bool, if x not in y return true, x in y return false.
"""
return not compile_utils.tensor_in_sequence(x, y)

+ 34
- 1
tests/ut/python/dtype/test_list.py View File

@@ -30,7 +30,6 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \

context.set_context(mode=context.GRAPH_MODE)


grad_all = C.GradOperation(get_all=True)


@@ -258,6 +257,34 @@ class AxisListDefaultNet(nn.Cell):
return self.reduce_sum(x)


class TensorInList(nn.Cell):
def __init__(self):
super(TensorInList, self).__init__()
self.t1 = Tensor(1, mstype.float32)
self.t2 = Tensor(2, mstype.float32)

def construct(self, x):
ret = x
list_ = [1, [2, 3], "str", self.t1, self.t2, x]
if x in list_:
ret = x + x
return ret


class TensorNotInList(nn.Cell):
def __init__(self):
super(TensorNotInList, self).__init__()
self.t1 = Tensor(1, mstype.float32)
self.t2 = Tensor(2, mstype.float32)

def construct(self, x):
ret = x
list_ = [self.t2, x]
if self.t1 not in list_:
ret = x + x
return ret


test_case_ops = [
('ListOperate', {
'block': ListOperate(),
@@ -275,6 +302,12 @@ test_case_ops = [
('InList', {
'block': InListNet(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
('TensorInList', {
'block': TensorInList(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
('TensorNotInList', {
'block': TensorNotInList(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}),
]

test_case_lists = [test_case_ops]


+ 36
- 2
tests/ut/python/dtype/test_tuple.py View File

@@ -53,7 +53,7 @@ class NestTupleGraphNet(nn.Cell):


class InTupleNet(nn.Cell):
def __init__(self,):
def __init__(self):
super(InTupleNet, self).__init__()
self.tuple_ = (1, 2, 3, 4, 5, "ok")

@@ -66,6 +66,34 @@ class InTupleNet(nn.Cell):
return ret


class TensorInTuple(nn.Cell):
def __init__(self):
super(TensorInTuple, self).__init__()
self.t1 = Tensor(1, mstype.float32)
self.t2 = Tensor(2, mstype.float32)
self.tuple_ = (self.t1, self.t2)

def construct(self, x):
ret = x
if self.t1 in self.tuple_:
ret = x + x
return ret


class TensorNotInTuple(nn.Cell):
def __init__(self):
super(TensorNotInTuple, self).__init__()
self.t1 = Tensor(1, mstype.float32)
self.t2 = Tensor(2, mstype.float32)
self.tuple_ = (self.t1, self.t2)

def construct(self, x):
ret = x
if self.t1 not in self.tuple_:
ret = x + x
return ret


test_case_ops = [
('TupleGraph', {
'block': TupleGraphNet(),
@@ -75,7 +103,13 @@ test_case_ops = [
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('InTuple', {
'block': InTupleNet(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]})
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('TensorInTuple', {
'block': TensorInTuple(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
('TensorNotInTuple', {
'block': TensorNotInTuple(),
'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}),
]

test_case_lists = [test_case_ops]


Loading…
Cancel
Save