@@ -28,23 +28,12 @@ pack = P.Stack(axis=-1)

def _tensor_getitem(self, index):
    """Handle tensor getitem"""
    if isinstance(index, Tensor):
        return tensor_index_by_tensor(self, index)
    if isinstance(index, list):
        return tensor_index_by_list(self, index)
    if isinstance(index, tuple):
        return tensor_index_by_tuple(self, index)
    # bool type should be judged before int
    if isinstance(index, bool):
        return _tensor_index_by_bool(self, index)
    if isinstance(index, int):
        return _tensor_index_by_integer(self, index)
    if isinstance(index, slice):
        return tensor_index_by_slice(self, index)
    if index is None:
        return F.expand_dims(self, 0)
    if index is ...:
        return self
    if isinstance(index, (Tensor, int, slice)) or index in (None, ...):
        return tensor_index_by_tuple(self, (index,))
    raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
                     f"list and tuple, but got {index} with type {type(index)}.")
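
With the unified dispatch above, every supported non-tuple index is expected to behave exactly like the corresponding one-element tuple index. A minimal sketch of that equivalence (not part of the patch; assumes a MindSpore build that ships this getitem implementation, shapes shown in comments):

import numpy as np
from mindspore import Tensor

x = Tensor(np.arange(12).reshape(3, 4))
print(x[1].shape, x[(1,)].shape)   # (4,) (4,): int index behaves like a 1-tuple index
print(x[None].shape)               # (1, 3, 4): None inserts a leading axis
print(x[...].shape)                # (3, 4): Ellipsis returns the data unchanged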

@@ -149,17 +138,7 @@ def _expand_data_dims(data, tuple_index):
    return data, tuple_index_new

def tensor_index_by_slice(data, slice_index):
    """Tensor getitem by a single slice"""
    shape = F.shape(data)
    if not shape:
        const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor "
                                      "cannot be 0.")
    begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
    return F.strided_slice(data, begin_strides, end_strides, step_strides)

def tensor_index_by_number(data, number):
def tensor_index_by_number(data, number_index):
    """Tensor getitem by a Number which may be integer/float/bool value"""
    data_type = F.typeof(data)
    if const_utils.judge_index_type(data_type, mstype.tensor_type):

@@ -168,42 +147,35 @@ def tensor_index_by_number(data, number):
        min_data_rank, max_data_rank = 0, 8
        const_utils.judge_data_rank(data_rank, min_data_rank, max_data_rank)
    number_type = const_utils.check_number_index_type(number)
    number_type = const_utils.check_number_index_type(number_index)
    if number_type == const_utils.BOOL_:
        return _tensor_index_by_bool(data, number)
        return tensor_index_by_tuple(data, (number_index,))
    if number_type == const_utils.INT_:
        return _tensor_index_by_integer(data, number)
        return _tensor_index_by_integer(data, number_index)
    return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")

def _tensor_index_by_bool(data, bool_value):
# TODO wait to remove after setitem by Yang Linfeng
def _tensor_index_by_bool(data, bool_index):
    """Tensor getitem by a single bool value"""
    if bool_value:
    if bool_index:
        return F.expand_dims(data, 0)
    return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data))

def _tensor_index_by_integer(data, number):
def _tensor_index_by_integer(data, int_index):
    """Tensor getitem by a single integer number"""
    data_shape = F.shape(data)
    data_rank = len(data_shape)
    if data_rank == 0:
        return const_utils.raise_type_error("When tensor is indexed by an integer, the dimension of the tensor "
                                            "cannot be 0.")
    transformed_number = const_utils.check_and_transform_int_index(number, data_shape[0], const_utils.TENSOR_GETITEM)
    transformed_number = const_utils.check_and_transform_int_index(int_index, data_shape[0], const_utils.TENSOR_GETITEM)
    begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number)
    shrink_axis_mask = 1
    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
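
For context, a hedged sketch (not part of the patch) of what the integer path above lowers to: a StridedSlice with shrink_axis_mask=1, so the indexed axis is dropped from the output. The begin/end/strides values are illustrative stand-ins for what get_stride_info_from_integer would produce:

import numpy as np
import mindspore.ops.operations as P
from mindspore import Tensor

x = Tensor(np.arange(12).reshape(3, 4).astype(np.float32))
# Rough equivalent of x[1] on a (3, 4) tensor: slice [1:2, 0:4] and shrink the first axis.
out = P.StridedSlice(0, 0, 0, 0, 1)(x, (1, 0), (2, 4), (1, 1))
print(out.shape)  # (4,)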

def tensor_index_by_tensor(data, tensor_index):
    """Tensor getitem by a single tensor"""
    index_type = F.dtype(tensor_index)
    const_utils.check_index_type_valid(index_type, mstype.int_type, const_utils.TENSOR_GETITEM)
    tensor_index = F.cast(tensor_index, mstype.int64)
    return F.gather(data, tensor_index, 0)

def tensor_index_by_list(data, list_index):
    """Tensor getitem by list of int and bool"""
    data_shape = F.shape(data)

@@ -241,39 +213,11 @@ def tensor_index_by_tuple(data, tuple_index):
    indexes_types = hyper_map(F.typeof, tuple_index)
    contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
    if contain_type == const_utils.ALL_TENSOR:
        return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name)
    if contain_type == const_utils.ALL_BASIC:
        return _tensor_getitem_by_tuple_slice(data, tuple_index)
    return _tensor_getitem_by_tuple(data, tuple_index, op_name)

def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
    """Tensor getitem by a tuple of tensor."""
    data_shape = F.shape(data)
    tuple_index_len = len(tuple_index)
    indexes_types = hyper_map(F.dtype, tuple_index)
    const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
    tensor_index_shape = hyper_map(F.shape, tuple_index)
    broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
    if 0 in broadcast_shape:
        res_shape = broadcast_shape
        if tuple_index_len < len(data_shape):
            res_shape += data_shape[tuple_index_len:]
        res = const_utils.make_tensor([], data.dtype, res_shape)
        return res
    broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
    new_broadcast_tensors = ()
    for tensor in broadcast_tensors:
        new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
    indices = pack(new_broadcast_tensors)
    result = F.gather_nd(data, indices)
    return result

def _tensor_getitem_by_tuple_slice(data, tuple_index):
    """Tensor getitem by a tuple of slice"""
    data_shape = F.shape(data)

@@ -291,7 +235,7 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
    indexes_types = hyper_map(F.typeof, tuple_index)
    slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
        const_utils.get_pos_of_indexes_types(indexes_types, op_name)
    tuple_index_new = ()
    tuple_index_new, slice_shapes = (), ()
    for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
        if i in int_positions:

@@ -299,57 +243,56 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
            tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
            tensor_positions.append(i)
            tensor_positions += (i,)
        elif i in sequence_positions:
            sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
            tensor_index = const_utils.make_tensor(sequence_index)
            tensor_index = F.cast(tensor_index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
            tensor_positions.append(i)
            tensor_positions += (i,)
        elif i in tensor_positions:
            const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name)
            const_utils.check_type_valid(F.dtype(index), mstype.int_type, op_name)
            tensor_index = F.cast(index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
        elif i in slice_positions:
            slice_indexes.append(index)
            tuple_index_new += (index,)
            slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
            slice_shapes += (len(slice_ele_list_index),)
            tuple_index_new += (slice_ele_list_index,)
            slice_indexes.append(slice_ele_list_index)
    tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
    tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
    indexes_types = hyper_map(F.typeof, tuple_index_new)
    broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
        data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)
    broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
        const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                                    slice_shapes, op_name)
    if 0 in final_shape:
    if 0 in final_shape + data_shape:
        if tuple_index_len < data_rank:
            final_shape = final_shape + data_shape[tuple_index_len:]
        return const_utils.make_tensor([], data.dtype, final_shape)
    slice_number = 0
    final_index_tensors = []
    index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
    for i in range(tuple_index_len):
    slice_cnt = 0
    for i, index in enumerate(tuple_index_new):
        if i in tensor_positions:
            transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
                                                          tuple_index_new[i])
            transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, index)
            final_index_tensors.append(transform_tensor)
        if i in slice_positions:
            slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
            final_index_tensors.append(slice_tensor)
            slice_number += 1
        elif i in slice_positions:
            slice_index_tensor = const_utils.convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
                                                                     slice_shapes, fancy_position)
            final_index_tensors.append(slice_index_tensor)
            slice_cnt += 1
    indices = pack(final_index_tensors)
    result = F.gather_nd(data, indices)
    return result
    return F.gather_nd(data, indices)
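
As an aside, a hedged sketch of the final pack + gather_nd step: the per-axis index tensors, already broadcast to one common shape, are stacked on the last axis and used as coordinates. The P.Stack and P.GatherNd primitives referenced here mirror the pack/F.gather_nd calls in this file; exact operator names may vary by MindSpore version:

import numpy as np
from mindspore import Tensor
import mindspore.ops.operations as P

data = Tensor(np.arange(24).reshape(4, 6).astype(np.int64))
rows = Tensor(np.array([[0, 1], [2, 3]], np.int64))
cols = Tensor(np.array([[5, 4], [3, 2]], np.int64))
indices = P.Stack(axis=-1)((rows, cols))   # shape (2, 2, 2): one (row, col) pair per output element
print(P.GatherNd()(data, indices))         # [[ 5 10] [15 20]]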

def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
    """Generate an indices tensor from a tuple of tensor."""
    indices = None
    indexes_types = hyper_map(F.dtype, tuple_index)
    const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
    const_utils.check_types_valid(indexes_types, mstype.int_type, op_name)
    tensor_index_shape = hyper_map(F.shape, tuple_index)
    broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
    broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)

@@ -363,12 +306,11 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
def _generate_indices_from_tuple(data, tuple_index, op_name):
    """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
    data_shape = F.shape(data)
    tuple_index_len = len(tuple_index)
    tensor_indexes, slice_indexes = [], []
    indexes_types = hyper_map(F.typeof, tuple_index)
    slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
        const_utils.get_pos_of_indexes_types(indexes_types, op_name)
    tuple_index_new = ()
    tuple_index_new, slice_shapes = (), ()
    for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
        if i in int_positions:

@@ -376,41 +318,41 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
            tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
            tensor_positions.append(i)
            tensor_positions += (i,)
        elif i in sequence_positions:
            sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
            tensor_index = const_utils.make_tensor(sequence_index)
            tensor_index = F.cast(tensor_index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
            tensor_positions.append(i)
            tensor_positions += (i,)
        elif i in tensor_positions:
            const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name)
            const_utils.check_type_valid(F.dtype(index), mstype.int_type, op_name)
            tensor_index = F.cast(index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
        elif i in slice_positions:
            slice_indexes.append(index)
            tuple_index_new += (index,)
            slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
            slice_shapes += (len(slice_ele_list_index),)
            tuple_index_new += (slice_ele_list_index,)
            slice_indexes.append(slice_ele_list_index)
    tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
    tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
    indexes_types = hyper_map(F.typeof, tuple_index_new)
    broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
        data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)
    broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
        const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                                    slice_shapes, op_name)
    slice_number = 0
    final_index_tensors = []
    index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
    for i in range(tuple_index_len):
    slice_cnt = 0
    for i, index in enumerate(tuple_index_new):
        if i in tensor_positions:
            transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
                                                          tuple_index_new[i])
            transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, index)
            final_index_tensors.append(transform_tensor)
        if i in slice_positions:
            slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
            final_index_tensors.append(slice_tensor)
            slice_number += 1
        elif i in slice_positions:
            slice_index_tensor = const_utils.convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
                                                                     slice_shapes, fancy_position)
            final_index_tensors.append(slice_index_tensor)
            slice_cnt += 1
    indices = pack(final_index_tensors)
    return indices

@@ -530,10 +472,8 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
def tensor_setitem_by_tensor_with_tuple(data, index, value):
    """Assigns the tensor by tensor with tuple value."""
    index_dtype = F.dtype(index)
    check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
    result = None
    if check_dtype:
        result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
    const_utils.check_type_valid(index_dtype, (mstype.int32, mstype.int64), const_utils.TENSOR_SETITEM)
    result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
    return result

@@ -151,20 +151,6 @@ def judge_index_type(index_type, target_type):
    return False

@constexpr
def check_type_valid(dtype, target_type, op_name):
    if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
        raise TypeError(
            f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")

@constexpr
def check_index_type_valid(dtype, target_type, op_name):
    if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
        raise IndexError(
            f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")

@constexpr
def judge_indexes_types(dtypes, target_type):
    """Check a tuple of tensor data type."""

@@ -175,37 +161,45 @@ def judge_indexes_types(dtypes, target_type):
@constexpr
def check_indexes_types_valid(dtypes, target_type, op_name):
def check_type_valid(dtype, target_type, op_name):
    if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
        if op_name in (TENSOR_GETITEM, TENSOR_SETITEM):
            raise IndexError(
                f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")
        raise TypeError(
            f"The '{op_name}' doesn't support '{dtype}' and expect to receive {target_type}.")

@constexpr
def check_types_valid(dtypes, target_type, op_name):
    """Check a tuple of tensor data type."""
    for dtype in dtypes:
        if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
            raise IndexError(f"For '{op_name}', the all index tensor data types should be in {target_type}, "
                             f"but got {dtype}.")
        check_type_valid(dtype, target_type, op_name)

@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
    """Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
    slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
        sequence_positions = [], [], [], [], [], [], []
        sequence_positions = (), (), (), (), (), (), ()
    for i, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.Slice):
            slice_positions.append(i)
            slice_positions += (i,)
        elif isinstance(index_type, mstype.Ellipsis_):
            ellipsis_positions.append(i)
            ellipsis_positions += (i,)
        elif isinstance(index_type, mstype.none_type):
            none_positions.append(i)
            none_positions += (i,)
        elif isinstance(index_type, mstype.Int):
            int_positions.append(i)
            int_positions += (i,)
        elif isinstance(index_type, mstype.Bool):
            bool_positions.append(i)
            bool_positions += (i,)
        elif isinstance(index_type, mstype.tensor_type):
            tensor_positions.append(i)
            tensor_positions += (i,)
        elif isinstance(index_type, (list, tuple)):
            sequence_positions.append(i)
            sequence_positions += (i,)
        else:
            raise IndexError(f"For '{op_name}', the index elements only support "
                             f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
            raise IndexError(f"For '{op_name}', the index elements only support 'Slice', 'Ellipsis', 'None', "
                             f"'Tensor', 'int', 'List', 'Tuple', 'bool' but got {index_type}.")
    if len(ellipsis_positions) > 1:
        raise IndexError(
            f"For '{op_name}', an index can only have a single ellipsis('...')")

@@ -394,8 +388,6 @@ def check_value_elements(data_dtype, types):
        raise TypeError(
            f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")

# TODO to del
@constexpr
def get_index_tensor_dtype(dtype):

@@ -408,28 +400,6 @@ def get_index_tensor_dtype(dtype):
            f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")

# TODO to del
@constexpr
def check_index_tensors_dtype(indexes_types, op_name):
    """Check a tuple of tensor data type."""
    for index_type in indexes_types:
        if not index_type in (mstype.int32, mstype.int64):
            raise IndexError(f"For '{op_name}', the all index tensor data types should be "
                             f"mstype.int32, but got {index_type}.")
    return True

# TODO to del
@constexpr
def check_index_tensor_dtype(index_type, op_name):
    """Check a tensor data type."""
    if index_type in (mstype.int32, mstype.int64):
        return True
    raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, "
                     f"but got {index_type}.")

# TODO to del
@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
    """Check tensors data type same."""

@@ -527,31 +497,18 @@ def transform_sequence_index(sequence_index, shape, op_name):
@constexpr
def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
    """Convert a slice to a tensor."""
    shape = []
    count = 0
    array = None
    for ele in indexes_shapes_info:
        if isinstance(ele, list):
            if count == slice_number:
                array = np.array(ele, np.int32)
                shape.append(len(ele))
            else:
                # When the slice is not the slice looking for, the shape is filled with 1.
                shape.append(1)
            count += 1
        elif isinstance(ele, tuple):
            shape.extend([1] * len(ele))
        else:
            shape.append(1)
    if array is None:
        raise ValueError(
            f"For '{op_name}', generate tensor from 'slice' failed.")
    shape = [1] * len(slice_shapes)
    shape[slice_cnt] = slice_shapes[slice_cnt]
    shape = shape[:fancy_position] + [1] * len(broadcast_shape) + shape[fancy_position:]
    array = np.array(index, np.int64)
    array = np.reshape(array, shape)
    reps = compute_multiples(shape, final_shape)
    tensor = Tensor(np.tile(array, reps), mstype.int64)
    return tensor
    slice_index_tensor = Tensor(np.tile(array, reps), mstype.int64)
    return slice_index_tensor
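
A hedged NumPy walk-through of the new reshape/tile logic above, for illustrative values slice_shapes=(4, 5), broadcast_shape=(2, 3), fancy_position=1, slice_cnt=0 and final_shape=(4, 2, 3, 5); compute_multiples is assumed to return the per-axis tile counts:

import numpy as np

shape = [1, 1]                              # [1] * len(slice_shapes)
shape[0] = 4                                # shape[slice_cnt] = slice_shapes[slice_cnt]
shape = shape[:1] + [1, 1] + shape[1:]      # 1s for broadcast_shape inserted at fancy_position -> [4, 1, 1, 1]
array = np.reshape(np.arange(4), shape)
print(np.tile(array, (1, 2, 3, 5)).shape)   # (4, 2, 3, 5) == final_shape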

@constexpr

@@ -599,6 +556,15 @@ def generate_updates_shape(data_shape, index_shape, op_type):
    return updates_shape

@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
    slice_obj = slice(slice_index.start, slice_index.stop, slice_index.step)
    slice_ele_list = list(range(dim_len))[slice_obj]
    if not slice_ele_list:
        raise IndexError(f"An empty slice is not supported, got {slice_obj}")
    return slice_ele_list
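
A quick sketch of what this helper produces, using plain Python for dim_len == 6:

print(list(range(6))[slice(1, None, 2)])   # [1, 3, 5] -> contributes 3 to slice_shapes
print(list(range(6))[slice(4, 2)])         # []        -> would trigger the IndexError above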

@constexpr
def check_tuple_index_len(data_rank, tuple_index_len, op_name):
    """Check if the number of index tensor exceeds the dimension of the operated tensor."""

@@ -609,89 +575,36 @@ def check_tuple_index_len(data_rank, tuple_index_len, op_name):
@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes,
                                                    tensor_indexes_dtypes, slice_indexes, op_name):
def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                    slice_shapes, op_name):
    """
    Generate index info which contain broadcast shape, final shape,
    indexes shapes info, ellipsis size from a tuple of mixed tensors.
    """
    check_index_tensors_dtype(tensor_indexes_dtypes, op_name)
    data_rank = len(data_shape)
    indexes_size = len(indexes_types)
    if indexes_size > data_rank:
        raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements "
                         f"is greater than the dimension {len(data_shape)} of the operated tensor.")
    indexes_info, index_tensors_info = {}, {}
    tensor_count, slice_count = 0, 0
    for pos, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.tensor_type):
            indexes_info[pos] = tensor_indexes_shapes[tensor_count]
            index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
            tensor_count += 1
        elif isinstance(index_type, mstype.Slice):
            slice_obj = slice(slice_indexes[slice_count].start,
                              slice_indexes[slice_count].stop,
                              slice_indexes[slice_count].step)
            # Use list to represent slicing result.
            indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
            if not indexes_info[pos]:
                raise IndexError("An empty slice is not supported, got {}:{}:{}".format(
                    slice_indexes[slice_count].start,
                    slice_indexes[slice_count].stop,
                    slice_indexes[slice_count].step))
            slice_count += 1
        else:
            raise IndexError(f"For '{op_name}', the index elements only support "
                             f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {index_type}.")
    broadcast_shape, final_shape, indexes_shapes_info = _derive_result_shape_info_from_tuple_of_mixed_tensors(
        indexes_info, index_tensors_info, op_name)
    return broadcast_shape, final_shape, indexes_shapes_info
    tensor_positions = tuple(sorted(tensor_positions))
    tensor_index_continue_tag = _judge_order_continuous(tensor_positions)
    fancy_position = tensor_positions[0] if tensor_index_continue_tag else 0
    broadcast_shape = generate_broadcast_shape(tensor_indexes_shapes, op_name)
    index_tensor_new_shape, final_shape = [], []
    if tensor_index_continue_tag:
        final_shape = slice_shapes[:fancy_position] + broadcast_shape + slice_shapes[fancy_position:]
        index_tensor_new_shape = (1,) * len(slice_shapes[:fancy_position]) + \
            broadcast_shape + (1,) * len(slice_shapes[fancy_position:])

def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list):
    """Determine whether the tensor in the index appears continuously."""
    for i in range(len(index_tensor_info_key) - 1):
        if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1:
            return False
    return True
    else:
        final_shape = broadcast_shape + slice_shapes
        index_tensor_new_shape = broadcast_shape + (1,) * len(slice_shapes)
    return broadcast_shape, index_tensor_new_shape, final_shape, fancy_position

def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name):
    """Derive the resulting shape information from the a tuple index of mixed tensors."""
    index_tensor_info_key = list(index_tensors_info.keys())
    index_tensor_info_value = list(index_tensors_info.values())
    broadcast_shape = generate_broadcast_shape(
        index_tensor_info_value, op_name)
    final_shape, indexes_shapes_info = [], []
    mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(
        index_tensor_info_key)
    if mixed_tensors_continuous:
        tensor_shape_dealt = False
        for ele in indexes_info.values():
            if isinstance(ele, list):
                final_shape.append(len(ele))
                indexes_shapes_info.append(ele)
            elif isinstance(ele, tuple):
                if not tensor_shape_dealt:
                    final_shape.extend(broadcast_shape)
                    indexes_shapes_info.append(broadcast_shape)
                    tensor_shape_dealt = True
            else:
                raise IndexError(f"For '{op_name}', the index elements only support "
                                 f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
    else:
        final_shape.extend(broadcast_shape)
        indexes_shapes_info.append(broadcast_shape)
        for ele in indexes_info.values():
            if isinstance(ele, list):
                final_shape.append(len(ele))
                indexes_shapes_info.append(ele)
            elif isinstance(ele, tuple):
                continue
            else:
                raise IndexError(f"For '{op_name}', the index elements only support "
                                 f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
    return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)

def _judge_order_continuous(order_sequence):
    if not order_sequence:
        return False
    for idx1, idx2 in zip(order_sequence[:-1], order_sequence[1:]):
        if idx1 + 1 != idx2:
            return False
    return True
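
The fancy_position logic above mirrors NumPy's advanced-indexing placement rule; a hedged NumPy analogy with illustrative shapes only: when the tensor indices sit in contiguous positions, their broadcast shape stays at the first tensor position, otherwise it moves to the front.

import numpy as np

x = np.zeros((5, 6, 7, 8))
idx = np.zeros((2, 3), dtype=np.int64)
print(x[:, idx, idx, :].shape)   # (5, 2, 3, 8): contiguous tensor indices -> fancy_position == 1
print(x[idx, :, idx, :].shape)   # (2, 3, 6, 8): split tensor indices      -> fancy_position == 0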

@constexpr

@@ -726,20 +639,6 @@ def check_number_index_type(number):
                     .format(number, type(number)))

@constexpr
def get_stride_info_from_slice(data_shape, slice_index):
    """Get stride info from a python slice"""
    begin, end, step = get_slice_stride(data_shape[0], slice_index)
    begin_strides = [begin]
    end_strides = [end]
    step_strides = [step]
    for end in data_shape[1:]:
        begin_strides.append(0)
        end_strides.append(end)
        step_strides.append(1)
    return tuple(begin_strides), tuple(end_strides), tuple(step_strides)

@constexpr
def get_stride_info_from_integer(data_shape, number):
    """Get stride info from a integer"""

@@ -173,7 +173,7 @@ def _tensor_getitem_by_none(data, none_index):
    Outputs:
        Tensor, element type is as same as the element type of data.
    """
    return F.expand_dims(data, 0)
    return compile_utils.tensor_index_by_tuple(data, (none_index,))

@getitem.register("Tensor", "Slice")

@@ -188,7 +188,7 @@ def _tensor_getitem_by_slice(data, slice_index):
    Outputs:
        Tensor, element type is the same as the element type of data.
    """
    return compile_utils.tensor_index_by_slice(data, slice_index)
    return compile_utils.tensor_index_by_tuple(data, (slice_index,))

@getitem.register("Tensor", "Tensor")

@@ -203,7 +203,7 @@ def _tensor_getitem_by_tensor(data, tensor_index):
    Outputs:
        Tensor, element type is the same as the element type of data.
    """
    return compile_utils.tensor_index_by_tensor(data, tensor_index)
    return compile_utils.tensor_index_by_tuple(data, (tensor_index,))

@getitem.register("Tensor", "Ellipsis")

@@ -218,7 +218,7 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
    Outputs:
        Tensor, same as data.
    """
    return data
    return compile_utils.tensor_index_by_tuple(data, (ellipsis_index,))

@getitem.register("Tensor", "List")

@@ -3011,7 +3011,7 @@ class StridedSlice(PrimitiveWithInfer):
                continue
            if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
                if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
                    raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
                    raise IndexError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
                                     f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
                                     f"but got stride: {stride}, begin: {begin}.")
                j += 1

@@ -155,7 +155,7 @@ class TensorGetItemByThreeTensors(Cell):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def Xtest_getitem_by_tensors():
    """This testcase may encounter a sync stream error occassionally"""
    """This testcase may encounter a sync stream error occasionally"""
    net = TensorGetItemByThreeTensors()
    input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
    index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)

@@ -1024,7 +1024,7 @@ def Xtest_tensor_slice_reduce_out_of_bounds_neg():
    input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
    net = NetWork()
    with pytest.raises(ValueError) as ex:
    with pytest.raises(IndexError) as ex:
        net(input_tensor)
    assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
        ex.value)

@@ -1042,7 +1042,7 @@ def Xtest_tensor_slice_reduce_out_of_bounds_positive():
    input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
    net = NetWork()
    with pytest.raises(ValueError) as ex:
    with pytest.raises(IndexError) as ex:
        net(input_tensor)
    assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)

@@ -1160,7 +1160,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
    input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
    net = NetWork()
    with pytest.raises(ValueError):
    with pytest.raises(IndexError):
        net(input_tensor)

@@ -1176,5 +1176,5 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
    input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
    net = NetWork()
    with pytest.raises(ValueError):
    with pytest.raises(IndexError):
        net(input_tensor)