| @@ -129,7 +129,7 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name): | |||
| return tuple_index_new | |||
| def _expand_data_dims(data, tuple_index, op_name): | |||
| def _expand_data_dims(data, tuple_index): | |||
| """expand the data's dim with 'None' and 'Boolean' in tuple_index""" | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| expand_positions, tuple_index_new = (), () | |||
| @@ -203,8 +203,14 @@ def tensor_index_by_list(data, list_index): | |||
| indexes_types = hyper_map(F.typeof, list_index) | |||
| if const_utils.judge_indexes_types(indexes_types, mstype.int_type + (mstype.bool_,)): | |||
| sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM) | |||
| if not sub_tuple_index: | |||
| data_rank = len(data_shape) | |||
| if data_rank == 1: | |||
| return const_utils.make_tensor([], data.dtype, ()) | |||
| return const_utils.make_tensor([], data.dtype, data_shape[1:]) | |||
| tensor_index = const_utils.make_tensor(sub_tuple_index, mstype.int64) | |||
| return F.gather(data, tensor_index, 0) | |||
| tuple_index_new = () | |||
| for index in list_index: | |||
| tuple_index_new += (index,) | |||
| @@ -219,7 +225,7 @@ def tensor_index_by_tuple(data, tuple_index): | |||
| op_name = const_utils.TENSOR_GETITEM | |||
| tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index) | |||
| data_shape = F.shape(data) | |||
| data_rank = len(data_shape) | |||
| @@ -228,6 +234,7 @@ def tensor_index_by_tuple(data, tuple_index): | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) | |||
| if contain_type == const_utils.ALL_TENSOR: | |||
| return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name) | |||
| if contain_type == const_utils.ALL_BASIC: | |||
| @@ -245,7 +252,9 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name): | |||
| tensor_index_shape = hyper_map(F.shape, tuple_index) | |||
| broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) | |||
| if 0 in broadcast_shape: | |||
| res_shape = broadcast_shape + data_shape[tuple_index_len:] | |||
| res_shape = broadcast_shape | |||
| if tuple_index_len < len(data_shape): | |||
| res_shape += data_shape[tuple_index_len:] | |||
| res = const_utils.make_tensor([], data.dtype, res_shape) | |||
| return res | |||
| @@ -268,12 +277,68 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index): | |||
| def _tensor_getitem_by_tuple(data, tuple_index, op_name): | |||
| """Tensor getitem by a tuple of mixed tensor.""" | |||
| indices = _generate_indices_from_tuple(data, tuple_index, op_name) | |||
| data_shape = F.shape(data) | |||
| data_rank = len(data_shape) | |||
| tuple_index_len = len(tuple_index) | |||
| tensor_indexes, slice_indexes = [], [] | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \ | |||
| const_utils.get_pos_of_indexes_types(indexes_types, op_name) | |||
| tuple_index_new = () | |||
| for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)): | |||
| if i in int_positions: | |||
| int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name) | |||
| tensor_index = F.scalar_to_tensor(int_index, mstype.int64) | |||
| tuple_index_new += (tensor_index,) | |||
| tensor_indexes.append(tensor_index) | |||
| tensor_positions.append(i) | |||
| elif i in sequence_positions: | |||
| sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name) | |||
| tensor_index = const_utils.make_tensor(sequence_index) | |||
| tensor_index = F.cast(tensor_index, mstype.int64) | |||
| tuple_index_new += (tensor_index,) | |||
| tensor_indexes.append(tensor_index) | |||
| tensor_positions.append(i) | |||
| elif i in tensor_positions: | |||
| const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name) | |||
| tensor_index = F.cast(index, mstype.int64) | |||
| tuple_index_new += (tensor_index,) | |||
| tensor_indexes.append(tensor_index) | |||
| elif i in slice_positions: | |||
| slice_indexes.append(index) | |||
| tuple_index_new += (index,) | |||
| tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) | |||
| tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) | |||
| indexes_types = hyper_map(F.typeof, tuple_index_new) | |||
| broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors( | |||
| data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name) | |||
| if 0 in final_shape: | |||
| if tuple_index_len < data_rank: | |||
| final_shape = final_shape + data_shape[tuple_index_len:] | |||
| return const_utils.make_tensor([], data.dtype, final_shape) | |||
| slice_number = 0 | |||
| final_index_tensors = [] | |||
| index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) | |||
| for i in range(tuple_index_len): | |||
| if i in tensor_positions: | |||
| transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, | |||
| tuple_index_new[i]) | |||
| final_index_tensors.append(transform_tensor) | |||
| if i in slice_positions: | |||
| slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) | |||
| final_index_tensors.append(slice_tensor) | |||
| slice_number += 1 | |||
| indices = pack(final_index_tensors) | |||
| result = F.gather_nd(data, indices) | |||
| return result | |||
| def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): | |||
| def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): | |||
| """Generate an indices tensor from a tuple of tensor.""" | |||
| indices = None | |||
| indexes_types = hyper_map(F.dtype, tuple_index) | |||
| @@ -510,13 +575,13 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): | |||
| return data | |||
| op_name = const_utils.TENSOR_GETITEM | |||
| tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index) | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if contain_type == const_utils.ALL_TENSOR: | |||
| indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) | |||
| indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) | |||
| else: | |||
| int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if int_cnt == const_utils.ALL_INT: | |||
| @@ -572,13 +637,13 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): | |||
| return data | |||
| op_name = const_utils.TENSOR_GETITEM | |||
| tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index) | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if contain_type == const_utils.ALL_TENSOR: | |||
| indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) | |||
| indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) | |||
| else: | |||
| int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if int_cnt == const_utils.ALL_INT: | |||
| @@ -600,13 +665,13 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): | |||
| return data | |||
| op_name = const_utils.TENSOR_GETITEM | |||
| tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index, op_name) | |||
| data, tuple_index = _expand_data_dims(data, tuple_index) | |||
| indexes_types = hyper_map(F.typeof, tuple_index) | |||
| contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if contain_type == const_utils.ALL_TENSOR: | |||
| indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) | |||
| indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) | |||
| else: | |||
| int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) | |||
| if int_cnt == const_utils.ALL_INT: | |||