Browse Source

!12714 rebuild getitem

From: @yepei6
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
e1a95b1152
6 changed files with 126 additions and 287 deletions
  1. +54
    -114
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  2. +62
    -163
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  3. +4
    -4
      mindspore/ops/composite/multitype_ops/getitem_impl.py
  4. +1
    -1
      mindspore/ops/operations/array_ops.py
  5. +3
    -3
      tests/st/pynative/test_tensor_index.py
  6. +2
    -2
      tests/ut/python/ops/test_tensor_slice.py

+ 54
- 114
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -28,23 +28,12 @@ pack = P.Stack(axis=-1)

def _tensor_getitem(self, index):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
return tensor_index_by_tensor(self, index)
if isinstance(index, list):
return tensor_index_by_list(self, index)
if isinstance(index, tuple):
return tensor_index_by_tuple(self, index)
# bool type should be judged before int
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if isinstance(index, int):
return _tensor_index_by_integer(self, index)
if isinstance(index, slice):
return tensor_index_by_slice(self, index)
if index is None:
return F.expand_dims(self, 0)
if index is ...:
return self
if isinstance(index, (Tensor, int, slice)) or index in (None, ...):
return tensor_index_by_tuple(self, (index,))
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
f"list and tuple ,but got {index} with type {type(index)}.")

@@ -149,17 +138,7 @@ def _expand_data_dims(data, tuple_index):
return data, tuple_index_new


def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor"
"cannot be 0.")
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
return F.strided_slice(data, begin_strides, end_strides, step_strides)


def tensor_index_by_number(data, number):
def tensor_index_by_number(data, number_index):
"""Tensor getitem by a Number which may be integer/float/bool value"""
data_type = F.typeof(data)
if const_utils.judge_index_type(data_type, mstype.tensor_type):
@@ -168,42 +147,35 @@ def tensor_index_by_number(data, number):
min_data_rank, max_data_rank = 0, 8
const_utils.judge_data_rank(data_rank, min_data_rank, max_data_rank)

number_type = const_utils.check_number_index_type(number)
number_type = const_utils.check_number_index_type(number_index)
if number_type == const_utils.BOOL_:
return _tensor_index_by_bool(data, number)
return tensor_index_by_tuple(data, (number_index,))
if number_type == const_utils.INT_:
return _tensor_index_by_integer(data, number)
return _tensor_index_by_integer(data, number_index)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")


def _tensor_index_by_bool(data, bool_value):
# TODO wait to remove after setitem by Yang Linfeng
def _tensor_index_by_bool(data, bool_index):
"""Tensor getitem by a single bool value"""
if bool_value:
if bool_index:
return F.expand_dims(data, 0)
return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data))


def _tensor_index_by_integer(data, number):
def _tensor_index_by_integer(data, int_index):
"""Tensor getitem by a single integer number"""
data_shape = F.shape(data)
data_rank = len(data_shape)
if data_rank == 0:
return const_utils.raise_type_error("When tensor is indexed by an integer, the dimension of the tensor "
"cannot be 0.")
transformed_number = const_utils.check_and_transform_int_index(number, data_shape[0], const_utils.TENSOR_GETITEM)
transformed_number = const_utils.check_and_transform_int_index(int_index, data_shape[0], const_utils.TENSOR_GETITEM)
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number)
shrink_axis_mask = 1
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


def tensor_index_by_tensor(data, tensor_index):
"""Tensor getitem by a single tensor"""
index_type = F.dtype(tensor_index)
const_utils.check_index_type_valid(index_type, mstype.int_type, const_utils.TENSOR_GETITEM)
tensor_index = F.cast(tensor_index, mstype.int64)
return F.gather(data, tensor_index, 0)


def tensor_index_by_list(data, list_index):
"""Tensor getitem by list of int and bool"""
data_shape = F.shape(data)
@@ -241,39 +213,11 @@ def tensor_index_by_tuple(data, tuple_index):

indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)

if contain_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name)
if contain_type == const_utils.ALL_BASIC:
return _tensor_getitem_by_tuple_slice(data, tuple_index)
return _tensor_getitem_by_tuple(data, tuple_index, op_name)


def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
"""Tensor getitem by a tuple of tensor."""
data_shape = F.shape(data)
tuple_index_len = len(tuple_index)

indexes_types = hyper_map(F.dtype, tuple_index)
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
if 0 in broadcast_shape:
res_shape = broadcast_shape
if tuple_index_len < len(data_shape):
res_shape += data_shape[tuple_index_len:]
res = const_utils.make_tensor([], data.dtype, res_shape)
return res

broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
new_broadcast_tensors = ()
for tensor in broadcast_tensors:
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
indices = pack(new_broadcast_tensors)
result = F.gather_nd(data, indices)
return result


def _tensor_getitem_by_tuple_slice(data, tuple_index):
"""Tensor getitem by a tuple of slice"""
data_shape = F.shape(data)
@@ -291,7 +235,7 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
indexes_types = hyper_map(F.typeof, tuple_index)
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
tuple_index_new = ()
tuple_index_new, slice_shapes = (), ()

for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
if i in int_positions:
@@ -299,57 +243,56 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
tensor_positions += (i,)
elif i in sequence_positions:
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
tensor_index = const_utils.make_tensor(sequence_index)
tensor_index = F.cast(tensor_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
tensor_positions += (i,)
elif i in tensor_positions:
const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name)
const_utils.check_type_valid(F.dtype(index), mstype.int_type, op_name)
tensor_index = F.cast(index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
elif i in slice_positions:
slice_indexes.append(index)
tuple_index_new += (index,)
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
slice_shapes += (len(slice_ele_list_index),)
tuple_index_new += (slice_ele_list_index,)
slice_indexes.append(slice_ele_list_index)

tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
indexes_types = hyper_map(F.typeof, tuple_index_new)
broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)
broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
slice_shapes, op_name)

if 0 in final_shape:
if 0 in final_shape + data_shape:
if tuple_index_len < data_rank:
final_shape = final_shape + data_shape[tuple_index_len:]
return const_utils.make_tensor([], data.dtype, final_shape)

slice_number = 0
final_index_tensors = []
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_len):
slice_cnt = 0
for i, index in enumerate(tuple_index_new):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
tuple_index_new[i])
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, index)
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
elif i in slice_positions:
slice_index_tensor = const_utils.convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
slice_shapes, fancy_position)
final_index_tensors.append(slice_index_tensor)
slice_cnt += 1

indices = pack(final_index_tensors)
result = F.gather_nd(data, indices)
return result
return F.gather_nd(data, indices)


def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
indexes_types = hyper_map(F.dtype, tuple_index)
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
const_utils.check_types_valid(indexes_types, mstype.int_type, op_name)
tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
@@ -363,12 +306,11 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
def _generate_indices_from_tuple(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
data_shape = F.shape(data)
tuple_index_len = len(tuple_index)
tensor_indexes, slice_indexes = [], []
indexes_types = hyper_map(F.typeof, tuple_index)
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
tuple_index_new = ()
tuple_index_new, slice_shapes = (), ()

for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
if i in int_positions:
@@ -376,41 +318,41 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
tensor_positions += (i,)
elif i in sequence_positions:
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
tensor_index = const_utils.make_tensor(sequence_index)
tensor_index = F.cast(tensor_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
tensor_positions += (i,)
elif i in tensor_positions:
const_utils.check_index_type_valid(F.dtype(index), mstype.int_type, op_name)
const_utils.check_type_valid(F.dtype(index), mstype.int_type, op_name)
tensor_index = F.cast(index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
elif i in slice_positions:
slice_indexes.append(index)
tuple_index_new += (index,)
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
slice_shapes += (len(slice_ele_list_index),)
tuple_index_new += (slice_ele_list_index,)
slice_indexes.append(slice_ele_list_index)

tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
indexes_types = hyper_map(F.typeof, tuple_index_new)
broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)
broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
slice_shapes, op_name)

slice_number = 0
final_index_tensors = []
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_len):
slice_cnt = 0
for i, index in enumerate(tuple_index_new):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
tuple_index_new[i])
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, index)
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
elif i in slice_positions:
slice_index_tensor = const_utils.convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
slice_shapes, fancy_position)
final_index_tensors.append(slice_index_tensor)
slice_cnt += 1

indices = pack(final_index_tensors)
return indices
@@ -530,10 +472,8 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
def tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Assigns the tensor by tensor with tuple value."""
index_dtype = F.dtype(index)
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
if check_dtype:
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
const_utils.check_type_valid(index_dtype, (mstype.int32, mstype.int64), const_utils.TENSOR_SETITEM)
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result




+ 62
- 163
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -151,20 +151,6 @@ def judge_index_type(index_type, target_type):
return False


@constexpr
def check_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise TypeError(
f"The '{op_name}' doesn't supoort {dtype}' and expect to receive {target_type}.")


@constexpr
def check_index_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise IndexError(
f"The '{op_name}' doesn't supoort {dtype}' and expect to receive {target_type}.")


@constexpr
def judge_indexes_types(dtypes, target_type):
"""Check a tuple of tensor data type."""
@@ -175,37 +161,45 @@ def judge_indexes_types(dtypes, target_type):


@constexpr
def check_indexes_types_valid(dtypes, target_type, op_name):
def check_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
if op_name in (TENSOR_GETITEM, TENSOR_SETITEM):
raise IndexError(
f"The '{op_name}' doesn't supoort {dtype}' and expect to receive {target_type}.")
raise TypeError(
f"The '{op_name}' doesn't supoort {dtype}' and expect to receive {target_type}.")


@constexpr
def check_types_valid(dtypes, target_type, op_name):
"""Check a tuple of tensor data type."""
for dtype in dtypes:
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise IndexError(f"For '{op_name}', the all index tensor data types should be in {target_type}, "
f"but got {dtype}.")
check_type_valid(dtype, target_type, op_name)


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
sequence_positions = [], [], [], [], [], [], []
sequence_positions = (), (), (), (), (), (), ()
for i, index_type in enumerate(indexes_types):
if isinstance(index_type, mstype.Slice):
slice_positions.append(i)
slice_positions += (i,)
elif isinstance(index_type, mstype.Ellipsis_):
ellipsis_positions.append(i)
ellipsis_positions += (i,)
elif isinstance(index_type, mstype.none_type):
none_positions.append(i)
none_positions += (i,)
elif isinstance(index_type, mstype.Int):
int_positions.append(i)
int_positions += (i,)
elif isinstance(index_type, mstype.Bool):
bool_positions.append(i)
bool_positions += (i,)
elif isinstance(index_type, mstype.tensor_type):
tensor_positions.append(i)
tensor_positions += (i,)
elif isinstance(index_type, (list, tuple)):
sequence_positions.append(i)
sequence_positions += (i,)
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
raise IndexError(f"For '{op_name}', the index elements only support 'Slice', 'Ellipsis', 'None', "
f"'Tensor', 'int', 'List', 'Tuple', 'bool' but got {index_type}.")
if len(ellipsis_positions) > 1:
raise IndexError(
f"For '{op_name}, an index can only have a single ellipsis('...')")
@@ -394,8 +388,6 @@ def check_value_elements(data_dtype, types):
raise TypeError(
f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")

# TODO to del


@constexpr
def get_index_tensor_dtype(dtype):
@@ -408,28 +400,6 @@ def get_index_tensor_dtype(dtype):
f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


# TODO to del
@constexpr
def check_index_tensors_dtype(indexes_types, op_name):
"""Check a tuple of tensor data type."""
for index_type in indexes_types:
if not index_type in (mstype.int32, mstype.int64):
raise IndexError(f"For '{op_name}', the all index tensor data types should be "
f"mstype.int32, but got {index_type}.")
return True


# TODO to del
@constexpr
def check_index_tensor_dtype(index_type, op_name):
"""Check a tensor data type."""
if index_type in (mstype.int32, mstype.int64):
return True
raise IndexError(f"For '{op_name}', the index tensor data type should be mstype.int32, "
f"but got {index_type}.")


# TODO to del
@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
"""Check tensors data type same."""
@@ -527,31 +497,18 @@ def transform_sequence_index(sequence_index, shape, op_name):


@constexpr
def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
"""Convert a slice to a tensor."""
shape = []
count = 0
array = None
for ele in indexes_shapes_info:
if isinstance(ele, list):
if count == slice_number:
array = np.array(ele, np.int32)
shape.append(len(ele))
else:
# When the slice is not the slice looking for, the shape is filled with 1.
shape.append(1)
count += 1
elif isinstance(ele, tuple):
shape.extend([1] * len(ele))
else:
shape.append(1)
if array is None:
raise ValueError(
f"For '{op_name}', generate tensor from 'slice' failed.")

shape = [1] * len(slice_shapes)
shape[slice_cnt] = slice_shapes[slice_cnt]
shape = shape[:fancy_position] + [1] * len(broadcast_shape) + shape[fancy_position:]

array = np.array(index, np.int64)
array = np.reshape(array, shape)
reps = compute_multiples(shape, final_shape)
tensor = Tensor(np.tile(array, reps), mstype.int64)
return tensor
slice_index_tensor = Tensor(np.tile(array, reps), mstype.int64)
return slice_index_tensor


@constexpr
@@ -599,6 +556,15 @@ def generate_updates_shape(data_shape, index_shape, op_type):
return updates_shape


@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
slice_obj = slice(slice_index.start, slice_index.stop, slice_index.step)
slice_ele_list = list(range(dim_len))[slice_obj]
if not slice_ele_list:
raise IndexError(f"An empty slice is not supported, got {slice_obj}")
return slice_ele_list


@constexpr
def check_tuple_index_len(data_rank, tuple_index_len, op_name):
"""Check if the number of index tensor exceeds the dimension of the operated tensor."""
@@ -609,89 +575,36 @@ def check_tuple_index_len(data_rank, tuple_index_len, op_name):


@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes,
tensor_indexes_dtypes, slice_indexes, op_name):
def generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
slice_shapes, op_name):
"""
Generate index info which contain broadcast shape, final shape,
indexes shapes info, ellipsis size from a tuple of mixed tensors.
"""
check_index_tensors_dtype(tensor_indexes_dtypes, op_name)
data_rank = len(data_shape)
indexes_size = len(indexes_types)
if indexes_size > data_rank:
raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements "
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
indexes_info, index_tensors_info = {}, {}
tensor_count, slice_count = 0, 0
for pos, index_type in enumerate(indexes_types):
if isinstance(index_type, mstype.tensor_type):
indexes_info[pos] = tensor_indexes_shapes[tensor_count]
index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
tensor_count += 1
elif isinstance(index_type, mstype.Slice):
slice_obj = slice(slice_indexes[slice_count].start,
slice_indexes[slice_count].stop,
slice_indexes[slice_count].step)
# Use list to represent slicing result.
indexes_info[pos] = list(range(data_shape[pos]))[slice_obj]
if not indexes_info[pos]:
raise IndexError("An empty slice is not supported, got {}:{}:{}".format(
slice_indexes[slice_count].start,
slice_indexes[slice_count].stop,
slice_indexes[slice_count].step))
slice_count += 1
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {index_type}.")
broadcast_shape, final_shape, indexes_shapes_info = _derive_result_shape_info_from_tuple_of_mixed_tensors(
indexes_info, index_tensors_info, op_name)
return broadcast_shape, final_shape, indexes_shapes_info
tensor_positions = tuple(sorted(tensor_positions))
tensor_index_continue_tag = _judge_order_continuous(tensor_positions)
fancy_position = tensor_positions[0] if tensor_index_continue_tag else 0
broadcast_shape = generate_broadcast_shape(tensor_indexes_shapes, op_name)
index_tensor_new_shape, final_shape = [], []

if tensor_index_continue_tag:
final_shape = slice_shapes[:fancy_position] + broadcast_shape + slice_shapes[fancy_position:]
index_tensor_new_shape = (1,) * len(slice_shapes[:fancy_position]) + \
broadcast_shape + (1,) * len(slice_shapes[fancy_position:])

def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list):
"""Determine whether the tensor in the index appears continuously."""
for i in range(len(index_tensor_info_key) - 1):
if index_tensor_info_key[i + 1] != index_tensor_info_key[i] + 1:
return False
return True
else:
final_shape = broadcast_shape + slice_shapes
index_tensor_new_shape = broadcast_shape + (1,) * len(slice_shapes)

return broadcast_shape, index_tensor_new_shape, final_shape, fancy_position

def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name):
"""Derive the resulting shape information from the a tuple index of mixed tensors."""
index_tensor_info_key = list(index_tensors_info.keys())
index_tensor_info_value = list(index_tensors_info.values())
broadcast_shape = generate_broadcast_shape(
index_tensor_info_value, op_name)
final_shape, indexes_shapes_info = [], []
mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(
index_tensor_info_key)
if mixed_tensors_continuous:
tensor_shape_dealt = False
for ele in indexes_info.values():
if isinstance(ele, list):
final_shape.append(len(ele))
indexes_shapes_info.append(ele)
elif isinstance(ele, tuple):
if not tensor_shape_dealt:
final_shape.extend(broadcast_shape)
indexes_shapes_info.append(broadcast_shape)
tensor_shape_dealt = True
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
else:
final_shape.extend(broadcast_shape)
indexes_shapes_info.append(broadcast_shape)
for ele in indexes_info.values():
if isinstance(ele, list):
final_shape.append(len(ele))
indexes_shapes_info.append(ele)
elif isinstance(ele, tuple):
continue
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {type(ele).__name__}.")
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)
def _judge_order_continuous(order_sequence):
if not order_sequence:
return False
for idx1, idx2 in zip(order_sequence[:-1], order_sequence[1:]):
if idx1 + 1 != idx2:
return False
return True


@constexpr
@@ -726,20 +639,6 @@ def check_number_index_type(number):
.format(number, type(number)))


@constexpr
def get_stride_info_from_slice(data_shape, slice_index):
"""Get stride info from a python slice"""
begin, end, step = get_slice_stride(data_shape[0], slice_index)
begin_strides = [begin]
end_strides = [end]
step_strides = [step]
for end in data_shape[1:]:
begin_strides.append(0)
end_strides.append(end)
step_strides.append(1)
return tuple(begin_strides), tuple(end_strides), tuple(step_strides)


@constexpr
def get_stride_info_from_integer(data_shape, number):
"""Get stride info from a integer"""


+ 4
- 4
mindspore/ops/composite/multitype_ops/getitem_impl.py View File

@@ -173,7 +173,7 @@ def _tensor_getitem_by_none(data, none_index):
Outputs:
Tensor, element type is as same as the element type of data.
"""
return F.expand_dims(data, 0)
return compile_utils.tensor_index_by_tuple(data, (none_index,))


@getitem.register("Tensor", "Slice")
@@ -188,7 +188,7 @@ def _tensor_getitem_by_slice(data, slice_index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
return compile_utils.tensor_index_by_slice(data, slice_index)
return compile_utils.tensor_index_by_tuple(data, (slice_index,))


@getitem.register("Tensor", "Tensor")
@@ -203,7 +203,7 @@ def _tensor_getitem_by_tensor(data, tensor_index):
Outputs:
Tensor, element type is the same as the element type of data.
"""
return compile_utils.tensor_index_by_tensor(data, tensor_index)
return compile_utils.tensor_index_by_tuple(data, (tensor_index,))


@getitem.register("Tensor", "Ellipsis")
@@ -218,7 +218,7 @@ def _tensor_getitem_by_ellipsis(data, ellipsis_index):
Outputs:
Tensor, same as data.
"""
return data
return compile_utils.tensor_index_by_tuple(data, (ellipsis_index,))


@getitem.register("Tensor", "List")


+ 1
- 1
mindspore/ops/operations/array_ops.py View File

@@ -3011,7 +3011,7 @@ class StridedSlice(PrimitiveWithInfer):
continue
if j < len(shrink_axis_pos) and shrink_axis_pos[j] == '1':
if (not -x_shape[i] <= begin < x_shape[i]) or stride < 0:
raise ValueError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
raise IndexError(f"For {self.name}, when shrink axis, the stride cannot be negative number, "
f"and begin should be in [-{x_shape[i]}, {x_shape[i]}), "
f"but got stride: {stride}, begin: {begin}.")
j += 1


+ 3
- 3
tests/st/pynative/test_tensor_index.py View File

@@ -155,7 +155,7 @@ class TensorGetItemByThreeTensors(Cell):
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def Xtest_getitem_by_tensors():
"""This testcase may encounter a sync stream error occassionally"""
"""This testcase may encounter a sync stream error occasionally"""
net = TensorGetItemByThreeTensors()
input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)
@@ -1024,7 +1024,7 @@ def Xtest_tensor_slice_reduce_out_of_bounds_neg():

input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError) as ex:
with pytest.raises(IndexError) as ex:
net(input_tensor)
assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
ex.value)
@@ -1042,7 +1042,7 @@ def Xtest_tensor_slice_reduce_out_of_bounds_positive():

input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError) as ex:
with pytest.raises(IndexError) as ex:
net(input_tensor)
assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)



+ 2
- 2
tests/ut/python/ops/test_tensor_slice.py View File

@@ -1160,7 +1160,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():

input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net(input_tensor)


@@ -1176,5 +1176,5 @@ def test_tensor_slice_reduce_out_of_bounds_positive():

input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
net = NetWork()
with pytest.raises(ValueError):
with pytest.raises(IndexError):
net(input_tensor)

Loading…
Cancel
Save