Merge pull request !7438 from huanghui/tensor-in-listtags/v1.1.0
| @@ -577,3 +577,12 @@ def tensor_setitem_by_ellipsis_with_tensor(data, index, value): | |||
| param2 = F.cast(value, data_dtype) | |||
| result = F.tensor_mul(param1, param2) | |||
| return result | |||
| def tensor_in_sequence(x, y): | |||
| """Assigns whether a sequence contains the given tensor""" | |||
| for i in y: | |||
| if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype: | |||
| if F.equal(x, i).all(): | |||
| return const_utils.scalar_to_tensor(True) | |||
| return const_utils.scalar_to_tensor(False) | |||
| @@ -39,14 +39,17 @@ TENSOR_GETITEM = "tensor getitem" | |||
| SET_ITEM_BY_ONE_TENSOR = 0 | |||
| SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | |||
| @constexpr | |||
| def raise_value_error(msg): | |||
| raise ValueError(msg) | |||
| @constexpr | |||
| def raise_index_error(msg): | |||
| raise IndexError(msg) | |||
| @constexpr | |||
| def raise_type_error(msg): | |||
| raise TypeError(msg) | |||
| @@ -704,7 +707,7 @@ def get_stride_info_from_slice(data_shape, slice_index): | |||
| def get_stride_info_from_integer(data_shape, number): | |||
| """Get stride info from a integer""" | |||
| begin_strides = [number] | |||
| end_strides = [number+1] | |||
| end_strides = [number + 1] | |||
| step_strides = [1] | |||
| for end in data_shape[1:]: | |||
| begin_strides.append(0) | |||
| @@ -720,7 +723,7 @@ def get_slice_stride(dim_size, index_slice): | |||
| stop_default = dim_size | |||
| if step < 0: | |||
| start_default = -1 | |||
| stop_default = -(dim_size+1) | |||
| stop_default = -(dim_size + 1) | |||
| start = start_default if index_slice.start is None else index_slice.start | |||
| stop = stop_default if index_slice.stop is None else index_slice.stop | |||
| return start, stop, step | |||
| @@ -775,3 +778,9 @@ def mstype_eq(x, y): | |||
| if x == y: | |||
| return True | |||
| return False | |||
| @constexpr | |||
| def scalar_to_tensor(x): | |||
| """Convert a scalar to a tensor""" | |||
| return Tensor(x) | |||
| @@ -16,6 +16,7 @@ | |||
| """Implementation for internal polymorphism `in` operations.""" | |||
| from . import _constexpr_utils as const_utils | |||
| from . import _compile_utils as compile_utils | |||
| from ... import functional as F | |||
| from ...composite import base | |||
| @@ -99,3 +100,33 @@ def _str_in_dict(x, y): | |||
| bool, if x in y return true, x not in y return false. | |||
| """ | |||
| return F.in_dict(x, y) | |||
| @in_.register("Tensor", "List") | |||
| def _tensor_in_list(x, y): | |||
| """ | |||
| Determine if a tensor in a list. | |||
| Args: | |||
| x: Tensor | |||
| y: List | |||
| Returns: | |||
| bool, if x in y return true, x not in y return false. | |||
| """ | |||
| return compile_utils.tensor_in_sequence(x, y) | |||
| @in_.register("Tensor", "Tuple") | |||
| def _tensor_in_tuple(x, y): | |||
| """ | |||
| Determine if a tensor in a tuple. | |||
| Args: | |||
| x: Tensor | |||
| y: Tuple | |||
| Returns: | |||
| bool, if x in y return true, x not in y return false. | |||
| """ | |||
| return compile_utils.tensor_in_sequence(x, y) | |||
| @@ -16,6 +16,7 @@ | |||
| """Implementation for internal polymorphism `not in` operations.""" | |||
| from . import _constexpr_utils as const_utils | |||
| from . import _compile_utils as compile_utils | |||
| from ... import functional as F | |||
| from ...composite import base | |||
| @@ -99,3 +100,33 @@ def _str_not_in_dict(x, y): | |||
| bool, if x not in y return true, x in y return false. | |||
| """ | |||
| return F.not_in_dict(x, y) | |||
| @not_in_.register("Tensor", "List") | |||
| def _tensor_not_in_list(x, y): | |||
| """ | |||
| Determine if a tensor not in a list. | |||
| Args: | |||
| x: Tensor | |||
| y: List | |||
| Returns: | |||
| bool, if x not in y return true, x in y return false. | |||
| """ | |||
| return not compile_utils.tensor_in_sequence(x, y) | |||
| @not_in_.register("Tensor", "Tuple") | |||
| def _tensor_not_in_tuple(x, y): | |||
| """ | |||
| Determine if a tensor not in a tuple. | |||
| Args: | |||
| x: Tensor | |||
| y: Tuple | |||
| Returns: | |||
| bool, if x not in y return true, x in y return false. | |||
| """ | |||
| return not compile_utils.tensor_in_sequence(x, y) | |||
| @@ -30,7 +30,6 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| grad_all = C.GradOperation(get_all=True) | |||
| @@ -258,6 +257,34 @@ class AxisListDefaultNet(nn.Cell): | |||
| return self.reduce_sum(x) | |||
| class TensorInList(nn.Cell): | |||
| def __init__(self): | |||
| super(TensorInList, self).__init__() | |||
| self.t1 = Tensor(1, mstype.float32) | |||
| self.t2 = Tensor(2, mstype.float32) | |||
| def construct(self, x): | |||
| ret = x | |||
| list_ = [1, [2, 3], "str", self.t1, self.t2, x] | |||
| if x in list_: | |||
| ret = x + x | |||
| return ret | |||
| class TensorNotInList(nn.Cell): | |||
| def __init__(self): | |||
| super(TensorNotInList, self).__init__() | |||
| self.t1 = Tensor(1, mstype.float32) | |||
| self.t2 = Tensor(2, mstype.float32) | |||
| def construct(self, x): | |||
| ret = x | |||
| list_ = [self.t2, x] | |||
| if self.t1 not in list_: | |||
| ret = x + x | |||
| return ret | |||
| test_case_ops = [ | |||
| ('ListOperate', { | |||
| 'block': ListOperate(), | |||
| @@ -275,6 +302,12 @@ test_case_ops = [ | |||
| ('InList', { | |||
| 'block': InListNet(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | |||
| ('TensorInList', { | |||
| 'block': TensorInList(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | |||
| ('TensorNotInList', { | |||
| 'block': TensorNotInList(), | |||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | |||
| ] | |||
| test_case_lists = [test_case_ops] | |||
| @@ -53,7 +53,7 @@ class NestTupleGraphNet(nn.Cell): | |||
| class InTupleNet(nn.Cell): | |||
| def __init__(self,): | |||
| def __init__(self): | |||
| super(InTupleNet, self).__init__() | |||
| self.tuple_ = (1, 2, 3, 4, 5, "ok") | |||
| @@ -66,6 +66,34 @@ class InTupleNet(nn.Cell): | |||
| return ret | |||
| class TensorInTuple(nn.Cell): | |||
| def __init__(self): | |||
| super(TensorInTuple, self).__init__() | |||
| self.t1 = Tensor(1, mstype.float32) | |||
| self.t2 = Tensor(2, mstype.float32) | |||
| self.tuple_ = (self.t1, self.t2) | |||
| def construct(self, x): | |||
| ret = x | |||
| if self.t1 in self.tuple_: | |||
| ret = x + x | |||
| return ret | |||
| class TensorNotInTuple(nn.Cell): | |||
| def __init__(self): | |||
| super(TensorNotInTuple, self).__init__() | |||
| self.t1 = Tensor(1, mstype.float32) | |||
| self.t2 = Tensor(2, mstype.float32) | |||
| self.tuple_ = (self.t1, self.t2) | |||
| def construct(self, x): | |||
| ret = x | |||
| if self.t1 not in self.tuple_: | |||
| ret = x + x | |||
| return ret | |||
| test_case_ops = [ | |||
| ('TupleGraph', { | |||
| 'block': TupleGraphNet(), | |||
| @@ -75,7 +103,13 @@ test_case_ops = [ | |||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | |||
| ('InTuple', { | |||
| 'block': InTupleNet(), | |||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}) | |||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | |||
| ('TensorInTuple', { | |||
| 'block': TensorInTuple(), | |||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | |||
| ('TensorNotInTuple', { | |||
| 'block': TensorNotInTuple(), | |||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | |||
| ] | |||
| test_case_lists = [test_case_ops] | |||