Merge pull request !7438 from huanghui/tensor-in-listtags/v1.1.0
| @@ -577,3 +577,12 @@ def tensor_setitem_by_ellipsis_with_tensor(data, index, value): | |||||
| param2 = F.cast(value, data_dtype) | param2 = F.cast(value, data_dtype) | ||||
| result = F.tensor_mul(param1, param2) | result = F.tensor_mul(param1, param2) | ||||
| return result | return result | ||||
| def tensor_in_sequence(x, y): | |||||
| """Assigns whether a sequence contains the given tensor""" | |||||
| for i in y: | |||||
| if isinstance(i, mstype.tensor) and x.shape == i.shape and x.dtype == i.dtype: | |||||
| if F.equal(x, i).all(): | |||||
| return const_utils.scalar_to_tensor(True) | |||||
| return const_utils.scalar_to_tensor(False) | |||||
| @@ -39,14 +39,17 @@ TENSOR_GETITEM = "tensor getitem" | |||||
| SET_ITEM_BY_ONE_TENSOR = 0 | SET_ITEM_BY_ONE_TENSOR = 0 | ||||
| SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | SET_ITEM_BY_TUPLE_OF_TENSOR = 1 | ||||
| @constexpr | @constexpr | ||||
| def raise_value_error(msg): | def raise_value_error(msg): | ||||
| raise ValueError(msg) | raise ValueError(msg) | ||||
| @constexpr | @constexpr | ||||
| def raise_index_error(msg): | def raise_index_error(msg): | ||||
| raise IndexError(msg) | raise IndexError(msg) | ||||
| @constexpr | @constexpr | ||||
| def raise_type_error(msg): | def raise_type_error(msg): | ||||
| raise TypeError(msg) | raise TypeError(msg) | ||||
| @@ -704,7 +707,7 @@ def get_stride_info_from_slice(data_shape, slice_index): | |||||
| def get_stride_info_from_integer(data_shape, number): | def get_stride_info_from_integer(data_shape, number): | ||||
| """Get stride info from a integer""" | """Get stride info from a integer""" | ||||
| begin_strides = [number] | begin_strides = [number] | ||||
| end_strides = [number+1] | |||||
| end_strides = [number + 1] | |||||
| step_strides = [1] | step_strides = [1] | ||||
| for end in data_shape[1:]: | for end in data_shape[1:]: | ||||
| begin_strides.append(0) | begin_strides.append(0) | ||||
| @@ -720,7 +723,7 @@ def get_slice_stride(dim_size, index_slice): | |||||
| stop_default = dim_size | stop_default = dim_size | ||||
| if step < 0: | if step < 0: | ||||
| start_default = -1 | start_default = -1 | ||||
| stop_default = -(dim_size+1) | |||||
| stop_default = -(dim_size + 1) | |||||
| start = start_default if index_slice.start is None else index_slice.start | start = start_default if index_slice.start is None else index_slice.start | ||||
| stop = stop_default if index_slice.stop is None else index_slice.stop | stop = stop_default if index_slice.stop is None else index_slice.stop | ||||
| return start, stop, step | return start, stop, step | ||||
| @@ -775,3 +778,9 @@ def mstype_eq(x, y): | |||||
| if x == y: | if x == y: | ||||
| return True | return True | ||||
| return False | return False | ||||
| @constexpr | |||||
| def scalar_to_tensor(x): | |||||
| """Convert a scalar to a tensor""" | |||||
| return Tensor(x) | |||||
| @@ -16,6 +16,7 @@ | |||||
| """Implementation for internal polymorphism `in` operations.""" | """Implementation for internal polymorphism `in` operations.""" | ||||
| from . import _constexpr_utils as const_utils | from . import _constexpr_utils as const_utils | ||||
| from . import _compile_utils as compile_utils | |||||
| from ... import functional as F | from ... import functional as F | ||||
| from ...composite import base | from ...composite import base | ||||
| @@ -99,3 +100,33 @@ def _str_in_dict(x, y): | |||||
| bool, if x in y return true, x not in y return false. | bool, if x in y return true, x not in y return false. | ||||
| """ | """ | ||||
| return F.in_dict(x, y) | return F.in_dict(x, y) | ||||
| @in_.register("Tensor", "List") | |||||
| def _tensor_in_list(x, y): | |||||
| """ | |||||
| Determine if a tensor in a list. | |||||
| Args: | |||||
| x: Tensor | |||||
| y: List | |||||
| Returns: | |||||
| bool, if x in y return true, x not in y return false. | |||||
| """ | |||||
| return compile_utils.tensor_in_sequence(x, y) | |||||
| @in_.register("Tensor", "Tuple") | |||||
| def _tensor_in_tuple(x, y): | |||||
| """ | |||||
| Determine if a tensor in a tuple. | |||||
| Args: | |||||
| x: Tensor | |||||
| y: Tuple | |||||
| Returns: | |||||
| bool, if x in y return true, x not in y return false. | |||||
| """ | |||||
| return compile_utils.tensor_in_sequence(x, y) | |||||
| @@ -16,6 +16,7 @@ | |||||
| """Implementation for internal polymorphism `not in` operations.""" | """Implementation for internal polymorphism `not in` operations.""" | ||||
| from . import _constexpr_utils as const_utils | from . import _constexpr_utils as const_utils | ||||
| from . import _compile_utils as compile_utils | |||||
| from ... import functional as F | from ... import functional as F | ||||
| from ...composite import base | from ...composite import base | ||||
| @@ -99,3 +100,33 @@ def _str_not_in_dict(x, y): | |||||
| bool, if x not in y return true, x in y return false. | bool, if x not in y return true, x in y return false. | ||||
| """ | """ | ||||
| return F.not_in_dict(x, y) | return F.not_in_dict(x, y) | ||||
| @not_in_.register("Tensor", "List") | |||||
| def _tensor_not_in_list(x, y): | |||||
| """ | |||||
| Determine if a tensor not in a list. | |||||
| Args: | |||||
| x: Tensor | |||||
| y: List | |||||
| Returns: | |||||
| bool, if x not in y return true, x in y return false. | |||||
| """ | |||||
| return not compile_utils.tensor_in_sequence(x, y) | |||||
| @not_in_.register("Tensor", "Tuple") | |||||
| def _tensor_not_in_tuple(x, y): | |||||
| """ | |||||
| Determine if a tensor not in a tuple. | |||||
| Args: | |||||
| x: Tensor | |||||
| y: Tuple | |||||
| Returns: | |||||
| bool, if x not in y return true, x in y return false. | |||||
| """ | |||||
| return not compile_utils.tensor_in_sequence(x, y) | |||||
| @@ -30,7 +30,6 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \ | |||||
| context.set_context(mode=context.GRAPH_MODE) | context.set_context(mode=context.GRAPH_MODE) | ||||
| grad_all = C.GradOperation(get_all=True) | grad_all = C.GradOperation(get_all=True) | ||||
| @@ -258,6 +257,34 @@ class AxisListDefaultNet(nn.Cell): | |||||
| return self.reduce_sum(x) | return self.reduce_sum(x) | ||||
| class TensorInList(nn.Cell): | |||||
| def __init__(self): | |||||
| super(TensorInList, self).__init__() | |||||
| self.t1 = Tensor(1, mstype.float32) | |||||
| self.t2 = Tensor(2, mstype.float32) | |||||
| def construct(self, x): | |||||
| ret = x | |||||
| list_ = [1, [2, 3], "str", self.t1, self.t2, x] | |||||
| if x in list_: | |||||
| ret = x + x | |||||
| return ret | |||||
| class TensorNotInList(nn.Cell): | |||||
| def __init__(self): | |||||
| super(TensorNotInList, self).__init__() | |||||
| self.t1 = Tensor(1, mstype.float32) | |||||
| self.t2 = Tensor(2, mstype.float32) | |||||
| def construct(self, x): | |||||
| ret = x | |||||
| list_ = [self.t2, x] | |||||
| if self.t1 not in list_: | |||||
| ret = x + x | |||||
| return ret | |||||
| test_case_ops = [ | test_case_ops = [ | ||||
| ('ListOperate', { | ('ListOperate', { | ||||
| 'block': ListOperate(), | 'block': ListOperate(), | ||||
| @@ -275,6 +302,12 @@ test_case_ops = [ | |||||
| ('InList', { | ('InList', { | ||||
| 'block': InListNet(), | 'block': InListNet(), | ||||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | ||||
| ('TensorInList', { | |||||
| 'block': TensorInList(), | |||||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | |||||
| ('TensorNotInList', { | |||||
| 'block': TensorNotInList(), | |||||
| 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))]}), | |||||
| ] | ] | ||||
| test_case_lists = [test_case_ops] | test_case_lists = [test_case_ops] | ||||
| @@ -53,7 +53,7 @@ class NestTupleGraphNet(nn.Cell): | |||||
| class InTupleNet(nn.Cell): | class InTupleNet(nn.Cell): | ||||
| def __init__(self,): | |||||
| def __init__(self): | |||||
| super(InTupleNet, self).__init__() | super(InTupleNet, self).__init__() | ||||
| self.tuple_ = (1, 2, 3, 4, 5, "ok") | self.tuple_ = (1, 2, 3, 4, 5, "ok") | ||||
| @@ -66,6 +66,34 @@ class InTupleNet(nn.Cell): | |||||
| return ret | return ret | ||||
| class TensorInTuple(nn.Cell): | |||||
| def __init__(self): | |||||
| super(TensorInTuple, self).__init__() | |||||
| self.t1 = Tensor(1, mstype.float32) | |||||
| self.t2 = Tensor(2, mstype.float32) | |||||
| self.tuple_ = (self.t1, self.t2) | |||||
| def construct(self, x): | |||||
| ret = x | |||||
| if self.t1 in self.tuple_: | |||||
| ret = x + x | |||||
| return ret | |||||
| class TensorNotInTuple(nn.Cell): | |||||
| def __init__(self): | |||||
| super(TensorNotInTuple, self).__init__() | |||||
| self.t1 = Tensor(1, mstype.float32) | |||||
| self.t2 = Tensor(2, mstype.float32) | |||||
| self.tuple_ = (self.t1, self.t2) | |||||
| def construct(self, x): | |||||
| ret = x | |||||
| if self.t1 not in self.tuple_: | |||||
| ret = x + x | |||||
| return ret | |||||
| test_case_ops = [ | test_case_ops = [ | ||||
| ('TupleGraph', { | ('TupleGraph', { | ||||
| 'block': TupleGraphNet(), | 'block': TupleGraphNet(), | ||||
| @@ -75,7 +103,13 @@ test_case_ops = [ | |||||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | ||||
| ('InTuple', { | ('InTuple', { | ||||
| 'block': InTupleNet(), | 'block': InTupleNet(), | ||||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}) | |||||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | |||||
| ('TensorInTuple', { | |||||
| 'block': TensorInTuple(), | |||||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | |||||
| ('TensorNotInTuple', { | |||||
| 'block': TensorNotInTuple(), | |||||
| 'desc_inputs': [Tensor(np.ones((3, 3, 24, 24)), mstype.float32)]}), | |||||
| ] | ] | ||||
| test_case_lists = [test_case_ops] | test_case_lists = [test_case_ops] | ||||