Browse Source

Fix tensor __getitem__ indexing bug

tags/v1.1.0
Payne 5 years ago
parent
commit
4f78ee0077
4 changed files with 92 additions and 38 deletions
  1. +2
    -0
      mindspore/common/dtype.py
  2. +41
    -19
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  3. +44
    -14
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  4. +5
    -5
      tests/ut/python/ops/test_tensor_fancy_index.py

+ 2
- 0
mindspore/common/dtype.py View File

@@ -105,6 +105,8 @@ index_slices = typing.RowTensorType()
sparse_tensor = typing.SparseTensorType()
undetermined = typing.UndeterminedType()
Int = typing.Int
bool_type = typing.Bool
none_type = typing.TypeNone

number_type = (int8,
int16,


+ 41
- 19
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -42,6 +42,27 @@ def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
return _broadcast(final_shape, F.reshape(x, new_shape))


def _transform_ellipsis_to_slice(tuple_index, data, op_name):
    """
    Replace an Ellipsis inside a tuple index with an equivalent run of full slices.

    Each Ellipsis position is expanded into as many empty slices
    (``slice(None, None, None)``) as needed so that the resulting tuple
    addresses every dimension of `data` explicitly.

    Args:
        tuple_index: the original tuple of index elements (may contain Ellipsis).
        data: the tensor being indexed; only its rank is used here.
        op_name: operation name, forwarded to const_utils for error reporting.

    Returns:
        A new tuple index with every Ellipsis expanded into empty slices.
    """
    data_shape = F.shape(data)
    data_rank = len(data_shape)
    indexes_types = hyper_map(F.typeof, tuple_index)
    slice_positions, ellipsis_positions, _, int_positions, _, tensor_positions, sequence_positions = \
        const_utils.get_pos_of_indexes_types(indexes_types, op_name)
    # Dimensions consumed by the Ellipsis = rank minus dimensions consumed by
    # every other index kind. NOTE(review): None/bool positions are ignored in
    # this count — presumably because they do not consume a data dimension;
    # confirm against the callers.
    ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) +
                                        len(tensor_positions) + len(sequence_positions))

    tuple_index_new = ()
    for i, index in enumerate(tuple_index):
        if i in ellipsis_positions:
            # Expand the Ellipsis into `ellipsis_occupy_dims` full slices.
            for _ in range(ellipsis_occupy_dims):
                empty_slice = const_utils.make_empty_slice()
                tuple_index_new += (empty_slice,)
        else:
            tuple_index_new += (index,)
    return tuple_index_new


def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
@@ -64,6 +85,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types)
tuple_index_new = ()
tuple_len = len(tuple_index)

for i in range(tuple_len):
index = tuple_index[i]
shape = data_shape[i]
@@ -77,15 +99,16 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
tuple_index_new += (tensor_index,)
else:
tuple_index_new += (index,)

indexes_types = hyper_map(F.typeof, tuple_index_new)
tensor_positions, slice_positions, ellipsis_position = \
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
tensor_indexes = []
slice_indexes = []
tensor_indexes, slice_indexes = [], []
for i in tensor_positions:
tensor_indexes.append(tuple_index_new[i])
for j in slice_positions:
slice_indexes.append(tuple_index_new[j])

tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
@@ -320,12 +343,12 @@ def tensor_index_by_tensor(data, tensor_index):

def _tensor_index_by_tuple_slice(data, tuple_index):
"""Tensor getitem by a tuple of slice"""
shape = F.shape(data)
if len(tuple_index) > len(shape):
const_utils.raise_index_error("When tensor is indexed by a tuple, "
"the length of the tuple cannot be greater than the dimension of the tensor.")
data_shape = F.shape(data)
if len(tuple_index) > len(data_shape):
const_utils.raise_index_error("When tensor is indexed by a tuple, the length of the tuple cannot "
"be greater than the dimension of the tensor.")
begin_strides, end_strides, step_strides, shrink_axis_mask = \
const_utils.get_stride_info_from_tuple(shape, tuple_index)
const_utils.get_stride_info_from_tuple(data_shape, tuple_index)
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


@@ -340,16 +363,15 @@ def tensor_index_by_list(data, list_index):

def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types with None"""
tuple_index_without_none = tuple_index
if len(tuple_index) == 1:
return data[tuple_index_without_none[0]]
indexes_types = hyper_map(F.typeof, tuple_index_without_none)
return data[tuple_index[0]]
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM)
if contain_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
if contain_type == const_utils.ALL_BASIC:
return _tensor_index_by_tuple_slice(data, tuple_index)
return _tensor_getitem_by_tuple(data, tuple_index_without_none)
return _tensor_getitem_by_tuple(data, tuple_index)


def _tensor_setitem(self, index, value):
@@ -456,7 +478,7 @@ def tensor_setitem_by_tensor_with_number(data, index, value):


def tensor_setitem_by_tensor_with_tuple(data, index, value):
"""Assigns the tensor by tensor with tuple value."""
"""Assigns the tensor by tensor with tuple value."""
index_dtype = F.dtype(index)
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM)
result = None
@@ -505,7 +527,7 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value):


def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
"""Assigns the tensor by tuple with number value."""
"""Assigns the tensor by tuple with number value."""
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
@@ -570,7 +592,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):


def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""Assigns the tensor by tuple with tensor value."""
"""Assigns the tensor by tuple with tensor value."""
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
@@ -614,7 +636,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):


def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
"""Assigns the tensor by tuple with tuple of value."""
"""Assigns the tensor by tuple with tuple of value."""
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
@@ -640,28 +662,28 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):


def tensor_setitem_by_number_with_number(data, index, value):
"""Assigns the tensor by number with number value."""
"""Assigns the tensor by number with number value."""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_number(data, data_shape, index, indices, value)


def tensor_setitem_by_number_with_tensor(data, index, value):
"""Assigns the tensor by number with tensor value."""
"""Assigns the tensor by number with tensor value."""
data_shape = F.shape(data)
indices = const_utils.integer_to_indices(index, data_shape)
return _tensor_indices_tensor(data, data_shape, index, indices, value)


def tensor_setitem_by_ellipsis_with_number(data, index, value):
"""Assigns the tensor by ellipsis with number value."""
"""Assigns the tensor by ellipsis with number value."""
data_shape = F.shape(data)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)


def tensor_setitem_by_ellipsis_with_tensor(data, index, value):
"""Assigns the tensor by ellipsis with tensor value."""
"""Assigns the tensor by ellipsis with tensor value."""
result = None
data_shape = F.shape(data)
data_dtype = F.dtype(data)


+ 44
- 14
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -734,6 +734,12 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)


@constexpr
def make_empty_slice():
    """Return an empty (full-dimension) slice, i.e. ``slice(None, None, None)`` — the `:` index."""
    empty_slice = slice(None, None, None)
    return empty_slice


@constexpr
def get_pos_of_int_index(indexes_types):
"""Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis."""
@@ -770,12 +776,40 @@ def separate_mixed_tensors_index(indexes_types, op_name):
elif isinstance(ele_type, mstype.ellipsis_type):
ellipsis_position = i
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {ele_type}.")
raise TypeError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {ele_type}.")

return tensor_positions, slice_positions, ellipsis_position


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
    """
    Classify each element of a mixed tuple index by type and record its position.

    Args:
        indexes_types: sequence of mindspore types, one per index element
            (as produced by ``hyper_map(F.typeof, tuple_index)``).
        op_name: operation name used in the error message.

    Returns:
        Seven position lists, in order: slice, ellipsis, none, int, bool,
        tensor, and sequence (list/tuple) positions.

    Raises:
        IndexError: if an index element is of an unsupported type.
    """
    slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
        sequence_positions = [], [], [], [], [], [], []
    for i, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.slice_type):
            slice_positions.append(i)
        elif isinstance(index_type, mstype.ellipsis_type):
            ellipsis_positions.append(i)
        elif isinstance(index_type, mstype.none_type):
            none_positions.append(i)
        elif isinstance(index_type, mstype.Int):
            int_positions.append(i)
        elif isinstance(index_type, mstype.bool_type):
            bool_positions.append(i)
        elif isinstance(index_type, mstype.tensor_type):
            tensor_positions.append(i)
        elif isinstance(index_type, (list, tuple)):
            sequence_positions.append(i)
        else:
            # Fix: the old message omitted 'None', 'bool', 'list' and 'tuple',
            # which this function clearly accepts in the branches above.
            raise IndexError(f"For '{op_name}', the index elements only support "
                             f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', 'None', "
                             f"'bool', 'list', 'tuple', but got {index_type}.")

    return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
        tensor_positions, sequence_positions


@constexpr
def scalar_in_sequence(x, y):
"""Determine whether the scalar in the sequence."""
@@ -849,17 +883,13 @@ def get_slice_stride(dim_size, index_slice):


@constexpr
def get_stride_info_from_tuple(data_shape, index_tuple):
def get_stride_info_from_tuple(data_shape, tuple_index):
"""Get stride info from a tuple"""
begin_strides = []
end_strides = []
step_strides = []
index_size = len(index_tuple)
data_shape_size = len(data_shape)
shrink_axis = 0
index_count = 0
ellipsis_count = 0
for idx, item in enumerate(index_tuple):
begin_strides, end_strides, step_strides = [], [], []
tuple_index_len = len(tuple_index)
data_rank = len(data_shape)
shrink_axis, index_count, ellipsis_count = 0, 0, 0
for idx, item in enumerate(tuple_index):
if isinstance(item, slice):
start, stop, step = get_slice_stride(data_shape[idx], item)
begin_strides.append(start)
@@ -876,7 +906,7 @@ def get_stride_info_from_tuple(data_shape, index_tuple):
ellipsis_count = ellipsis_count + 1
if ellipsis_count > 1:
raise IndexError("An index can have only one ellipsis (...)")
ellipsis_range_size = data_shape_size - (index_size - 1)
ellipsis_range_size = data_rank - (tuple_index_len - 1)
begin_strides.extend([0] * (ellipsis_range_size))
end_strides.extend(
[i for i in data_shape[index_count: index_count + (ellipsis_range_size)]])
@@ -885,7 +915,7 @@ def get_stride_info_from_tuple(data_shape, index_tuple):
else:
raise IndexError("Not supported index data type, got ",
item, " type is ", type(item))
for item in range(index_count, data_shape_size):
for item in range(index_count, data_rank):
begin_strides.append(0)
end_strides.append(data_shape[item])
step_strides.append(1)


+ 5
- 5
tests/ut/python/ops/test_tensor_fancy_index.py View File

@@ -30,7 +30,7 @@ class NetWorkFancyIndex(Cell):
return tensor[self.index]


def test_tensor_fancy_index_integer_list_graph():
def test_tensor_fancy_index_integer_list():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = [0, 2, 1]
net = NetWorkFancyIndex(index)
@@ -39,7 +39,7 @@ def test_tensor_fancy_index_integer_list_graph():
net(input_me)


def test_tensor_fancy_boolean_list_graph():
def test_tensor_fancy_index_boolean_list():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = [True, True, False]
net = NetWorkFancyIndex(index)
@@ -57,7 +57,7 @@ def test_tensor_fancy_integer_boolean_list_graph():
net(input_me)


def test_tensor_fancy_integer_list_mixed_graph():
def test_tensor_fancy_integer_list_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, [2, 1, 3], slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)
@@ -66,7 +66,7 @@ def test_tensor_fancy_integer_list_mixed_graph():
net(input_me)


def test_tensor_fancy_integer_tuple_mixed_graph():
def test_tensor_fancy_integer_tuple_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, (2, 1, 3), slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)
@@ -75,7 +75,7 @@ def test_tensor_fancy_integer_tuple_mixed_graph():
net(input_me)


def test_tensor_fancy_integer_list_tuple_mixed_graph():
def test_tensor_fancy_integer_list_tuple_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, [2, 1, 3], (3, 2, 1), slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)


Loading…
Cancel
Save