| @@ -57,6 +57,68 @@ def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): | |||
| return indices | |||
def _generate_indices_from_tuple(data, tuple_index, op_name):
    """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
    data_shape = F.shape(data)
    indexes_types = hyper_map(F.typeof, tuple_index)
    # Positions inside tuple_index of plain ints and of int/bool sequences (list/tuple).
    int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types)
    tuple_index_new = ()
    tuple_len = len(tuple_index)
    # First pass: normalize int and sequence elements into int32 tensors, so that
    # afterwards the tuple only contains tensors, slices and ellipsis.
    # NOTE(review): assumes len(tuple_index) <= rank of data — confirm upstream validation.
    for i in range(tuple_len):
        index = tuple_index[i]
        shape = data_shape[i]
        if i in int_positions:
            # Bounds-check against the dim length and wrap negative indices.
            int_index = const_utils.check_and_transform_int_index(index, shape, op_name)
            tensor_index = F.scalar_to_tensor(int_index, mstype.int32)
            tuple_index_new += (tensor_index,)
        elif i in sequence_positions:
            # Turn a list/tuple (possibly a boolean mask) into a tuple of valid ints.
            sequence_index = const_utils.transform_sequence_index(index, shape, op_name)
            tensor_index = F.tuple_to_array(sequence_index)
            tuple_index_new += (tensor_index,)
        else:
            tuple_index_new += (index,)
    # Second pass works on the normalized tuple: split element positions into
    # tensor / slice / ellipsis groups.
    indexes_types = hyper_map(F.typeof, tuple_index_new)
    tensor_positions, slice_positions, ellipsis_position = \
        const_utils.separate_mixed_tensors_index(indexes_types, op_name)
    tensor_indexes = []
    slice_indexes = []
    for i in tensor_positions:
        tensor_indexes.append(tuple_index_new[i])
    for j in slice_positions:
        slice_indexes.append(tuple_index_new[j])
    tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
    tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
    # Compile-time computation of the broadcast shape of all tensor indices, the
    # final output shape, and how many dims the ellipsis expands to.
    broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
        const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
                                                                    indexes_types,
                                                                    tensor_indexes_shapes,
                                                                    tensor_indexes_dtypes,
                                                                    slice_indexes,
                                                                    op_name)
    slice_number = 0
    final_index_tensors = []
    tuple_index_size = len(tuple_index_new)
    index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
    # Third pass: convert every element (tensor, slice, ellipsis) into an index
    # tensor broadcast to the final shape, preserving original order.
    for i in range(tuple_index_size):
        if i in tensor_positions:
            transform_tensor = _transform_indexing_tensor(
                broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
            final_index_tensors.append(transform_tensor)
        if i in slice_positions:
            slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
            final_index_tensors.append(slice_tensor)
            # slice_number tracks which slice/ellipsis-dim we are materializing.
            slice_number += 1
        if i == ellipsis_position:
            ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
                slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
            for ele in ellipsis_tensors:
                final_index_tensors.append(ele)
            slice_number += ellipsis_occupied_dims
    # Stack per-dim index tensors into the indices tensor consumed by GatherNd/ScatterNd.
    indices = pack(final_index_tensors)
    return indices
| def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): | |||
| """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" | |||
| data_shape = F.shape(data) | |||
| @@ -160,6 +222,8 @@ def _tensor_getitem(self, index): | |||
| return tensor_index_by_tensor(self, index) | |||
| if isinstance(index, tuple): | |||
| return tensor_index_by_tuple(self, index) | |||
| if isinstance(index, list): | |||
| return tensor_index_by_list(self, index) | |||
| # bool type should be judged before int | |||
| if isinstance(index, bool): | |||
| return _tensor_index_by_bool(self, index) | |||
| @@ -187,6 +251,13 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): | |||
| return result | |||
def _tensor_getitem_by_tuple(data, tuple_index):
    """Index *data* with a tuple mixing int, sequence, slice, ellipsis and tensor elements."""
    gather_indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM)
    return F.gather_nd(data, gather_indices)
| def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): | |||
| """Tensor getitem by a tuple of mixed tensor.""" | |||
| indices = _generate_indices_from_tuple_of_mixed_tensors(data, | |||
| @@ -273,12 +344,12 @@ def tensor_index_by_tuple(data, tuple_index): | |||
| if len(tuple_index) == 1: | |||
| return data[tuple_index_without_none[0]] | |||
| indexes_types = hyper_map(F.typeof, tuple_index_without_none) | |||
| tensor_cnt = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM) | |||
| if tensor_cnt == const_utils.NO_TENSOR: | |||
| return _tensor_index_by_tuple_slice(data, tuple_index_without_none) | |||
| if tensor_cnt == const_utils.ALL_TENSOR: | |||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index_without_none) | |||
| return _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index_without_none) | |||
| contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM) | |||
| if contain_type == const_utils.ALL_TENSOR: | |||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) | |||
| if contain_type == const_utils.ALL_BASIC: | |||
| return _tensor_index_by_tuple_slice(data, tuple_index) | |||
| return _tensor_getitem_by_tuple(data, tuple_index_without_none) | |||
| def _tensor_setitem(self, index, value): | |||
| @@ -31,6 +31,8 @@ ALL_SCALAR = 3 | |||
| ALL_INT = 4 | |||
| NO_INT = 5 | |||
| CONTAIN_INT = 6 | |||
| ALL_BASIC = 7 | |||
| MIXED = 8 | |||
| INT_ = 0 | |||
| BOOL_ = 1 | |||
| @@ -307,6 +309,18 @@ def tuple_index_int_cnt(types, op_name): | |||
| return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT | |||
| @constexpr | |||
def tuple_index_type_cnt(types, op_name):
    """Classify a tuple index as all-tensor, all-basic (int/ellipsis/slice) or mixed."""
    total = len(types)
    if sum(1 for ele in types if isinstance(ele, mstype.tensor_type)) == total:
        return ALL_TENSOR
    basic_types = (mstype.Int, mstype.ellipsis_type, mstype.slice_type)
    if sum(1 for ele in types if isinstance(ele, basic_types)) == total:
        return ALL_BASIC
    return MIXED
| @constexpr | |||
| def check_value_elements(data_dtype, types): | |||
| """Judges the type of all elements of the tuple.""" | |||
| @@ -501,6 +515,34 @@ def convert_ellipsis_to_tensors(slice_number, | |||
| return tensor_list | |||
| @constexpr | |||
| def check_and_transform_int_index(index, shape, op_name): | |||
| if index < -shape or index >= shape: | |||
| raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit " | |||
| f"the corresponding dim length, but get {index}.") | |||
| if index < 0: | |||
| index += shape | |||
| return index | |||
| @constexpr | |||
def transform_sequence_index(sequence_index, shape, op_name):
    """Transform a list/tuple of ints and booleans into a tuple of valid int indices.

    A purely boolean sequence is treated as a mask over the dimension and must
    match its length; otherwise booleans are coerced to ints (True -> 1,
    False -> 0), matching NumPy fancy-indexing semantics. Every resulting int
    is bounds-checked and negative values are wrapped.
    """
    bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
    # bool is a subclass of int, so subtract bools to count true integers only.
    int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index))) - bool_count
    if int_count == 0:
        if bool_count == shape:
            # Boolean mask: keep positions whose mask value is True.
            list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
        else:
            # Fixed typo in user-facing message: "dimensiton" -> "dimension".
            raise IndexError("The boolean array should have the same length with the corresponding dimension")
    else:
        list_index = [int(index) for index in sequence_index]
    for i, index in enumerate(list_index):
        list_index[i] = check_and_transform_int_index(index, shape, op_name)
    sub_tuple_index = tuple(list_index)
    return sub_tuple_index
| @constexpr | |||
| def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): | |||
| """Convert a slice to a tensor.""" | |||
| @@ -702,6 +744,18 @@ def get_pos_of_int_index(indexes_types): | |||
| return int_positions | |||
| @constexpr | |||
def get_pos_of_int_sequence(indexes_types):
    """Collect positions of int indices and of sequence (list/tuple) indices."""
    int_positions = []
    sequence_positions = []
    for pos, idx_type in enumerate(indexes_types):
        if isinstance(idx_type, mstype.Int):
            int_positions.append(pos)
        elif isinstance(idx_type, (tuple, list)):
            sequence_positions.append(pos)
    return int_positions, sequence_positions
| @constexpr | |||
| def separate_mixed_tensors_index(indexes_types, op_name): | |||
| """Separate the position information of tensor and slice and ellipsis from the mixed tensors index.""" | |||
| @@ -206,21 +206,6 @@ def _tensor_getitem_by_tensor(data, tensor_index): | |||
| return compile_utils.tensor_index_by_tensor(data, tensor_index) | |||
| @getitem.register("Tensor", "Tuple") | |||
| def _tensor_getitem_by_tuple(data, tuple_index): | |||
| """ | |||
| Getting item of tensor by tuple. | |||
| Inputs: | |||
| data (Tensor): A tensor. | |||
| tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple. | |||
| Outputs: | |||
| Tensor, element type is the same as the element type of data. | |||
| """ | |||
| return compile_utils.tensor_index_by_tuple(data, tuple_index) | |||
| @getitem.register("Tensor", "Ellipsis") | |||
| def _tensor_getitem_by_ellipsis(data, ellipsis_index): | |||
| """ | |||
| @@ -249,3 +234,18 @@ def _tensor_getitem_by_list(data, list_index): | |||
| Tensor ,same as data. | |||
| """ | |||
| return compile_utils.tensor_index_by_list(data, list_index) | |||
@getitem.register("Tensor", "Tuple")
def _tensor_getitem_by_tuple(data, tuple_index):
    """
    Getting item of tensor by tuple.

    Dispatches to compile_utils, which handles every element kind the tuple may
    contain (ellipsis, slice, int, Tensor, None, list, tuple).

    Inputs:
        data (Tensor): A tensor.
        tuple_index (tuple): Index in tuple which include ellipsis, slice, int, Tensor, None, list, tuple.

    Outputs:
        Tensor, element type is the same as the element type of data.
    """
    return compile_utils.tensor_index_by_tuple(data, tuple_index)
| @@ -21,61 +21,64 @@ from mindspore import dtype as mstype | |||
| from mindspore.nn import Cell | |||
| class NetWorkFancyIndexBoolean(Cell): | |||
| class NetWorkFancyIndex(Cell): | |||
| def __init__(self, index): | |||
| super(NetWorkFancyIndexBoolean, self).__init__() | |||
| super(NetWorkFancyIndex, self).__init__() | |||
| self.index = index | |||
| def construct(self, tensor): | |||
| return tensor[self.index] | |||
| class NetWorkFancyIndexInterger(Cell): | |||
| def __init__(self, index): | |||
| super(NetWorkFancyIndexInterger, self).__init__() | |||
| self.index = index | |||
| def construct(self, tensor): | |||
| return tensor[self.index] | |||
| class NetWorkFancyIndexIntergerBooleanMixed(Cell): | |||
| def __init__(self, index): | |||
| super(NetWorkFancyIndexIntergerBooleanMixed, self).__init__() | |||
| self.index = index | |||
| def construct(self, tensor): | |||
| return tensor[self.index] | |||
| def test_tensor_fancy_index_integer_list(): | |||
| def test_tensor_fancy_index_integer_list_graph(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| index = [0, 2, 1] | |||
| net = NetWorkFancyIndexBoolean(index) | |||
| net = NetWorkFancyIndex(index) | |||
| input_np = np.arange(60).reshape(3, 4, 5) | |||
| input_me = Tensor(input_np, dtype=mstype.float32) | |||
| output_me = net(input_me).asnumpy() | |||
| output_np = input_np[index] | |||
| assert np.allclose(output_np, output_me, 0, 0) | |||
| net(input_me) | |||
| def test_tensor_fancy_boolean_list(): | |||
| def test_tensor_fancy_boolean_list_graph(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| index = [True, True, False] | |||
| net = NetWorkFancyIndexInterger(index) | |||
| net = NetWorkFancyIndex(index) | |||
| input_np = np.arange(60).reshape(3, 4, 5) | |||
| input_me = Tensor(input_np, dtype=mstype.float32) | |||
| output_me = net(input_me).asnumpy() | |||
| output_np = input_np[index] | |||
| assert np.allclose(output_np, output_me, 0, 0) | |||
| net(input_me) | |||
| def test_tensor_fancy_integer_boolean_list(): | |||
| def test_tensor_fancy_integer_boolean_list_graph(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| index = [1, 2, True, False] | |||
| net = NetWorkFancyIndexIntergerBooleanMixed(index) | |||
| net = NetWorkFancyIndex(index) | |||
| input_np = np.arange(60).reshape(3, 4, 5) | |||
| input_me = Tensor(input_np, dtype=mstype.float32) | |||
| output_me = net(input_me).asnumpy() | |||
| output_np = input_np[index] | |||
| assert np.allclose(output_np, output_me, 0, 0) | |||
| net(input_me) | |||
def test_tensor_fancy_integer_list_mixed_graph():
    """Graph-mode smoke test: tuple index mixing int, list, slice, ellipsis and int."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    fancy_index = (1, [2, 1, 3], slice(1, 3, 1), ..., 4)
    net = NetWorkFancyIndex(fancy_index)
    data_np = np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape(3, 4, 5, 6, 7, 8)
    net(Tensor(data_np, dtype=mstype.float32))
def test_tensor_fancy_integer_tuple_mixed_graph():
    """Graph-mode smoke test: tuple index mixing int, inner tuple, slice and ellipsis."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    fancy_index = (1, (2, 1, 3), slice(1, 3, 1), ..., 4)
    net = NetWorkFancyIndex(fancy_index)
    data_np = np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape(3, 4, 5, 6, 7, 8)
    net(Tensor(data_np, dtype=mstype.float32))
def test_tensor_fancy_integer_list_tuple_mixed_graph():
    """Graph-mode smoke test: tuple index mixing int, list, inner tuple, slice and ellipsis."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    fancy_index = (1, [2, 1, 3], (3, 2, 1), slice(1, 3, 1), ..., 4)
    net = NetWorkFancyIndex(fancy_index)
    data_np = np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape(3, 4, 5, 6, 7, 8)
    net(Tensor(data_np, dtype=mstype.float32))