Browse Source

!10853 getitem: bool index expand dims

From: @yepei6
Reviewed-by: @kisnwang,@kingxian
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
57f4d6b4c9
4 changed files with 139 additions and 85 deletions
  1. +1
    -0
      mindspore/common/dtype.py
  2. +37
    -9
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  3. +77
    -72
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  4. +24
    -4
      tests/ut/python/ops/test_tensor_fancy_index.py

+ 1
- 0
mindspore/common/dtype.py View File

@@ -93,6 +93,7 @@ env_type = typing.EnvType()
env_type_type = typing.EnvType
type_type = typing.TypeType()
type_none = typing.TypeNone()
type_bool = typing.Bool()
string = typing.String()
type_refkey = typing.RefKeyType()
tensor_type = typing.TensorType


+ 37
- 9
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -137,13 +137,37 @@ def _expand_data_dims_with_none(data, tuple_index, op_name):
none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none)
tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else(index,)
none_positions += (i,) if none_type_tag else ()

for dim in none_positions:
data = F.expand_dims(data, dim)

return data, tuple_index_without_none


def _expand_data_dims_with_bool(data, tuple_index, op_name):
    """Expand data's dims for each literal bool in tuple_index.

    Every True in the tuple adds a size-1 axis on `data` at that position and
    is replaced by a length-1 int64 tensor index [0]; False is rejected because
    zero-size dimensions are not yet supported by the downstream ops.
    Returns the (possibly expanded) data and the bool-free index tuple.
    """
    index_types = hyper_map(F.typeof, tuple_index)
    expanded_positions, new_tuple_index = (), ()

    for pos, (item, item_type) in enumerate(zip(tuple_index, index_types)):
        is_bool = const_utils.judge_index_type(item_type, mstype.type_bool)
        if not is_bool:
            new_tuple_index += (item,)
        elif item:
            new_tuple_index += (const_utils.make_tensor([0], mstype.int64),)
            expanded_positions += (pos,)
        else:
            # TODO: once ops support zero dim-size, build a 0-length tensor
            # here instead of raising.
            return const_utils.raise_index_error("When tensor is indexed by a tuple which contains bool object, "
                                                 "the value only support 'True'.")

    for axis in expanded_positions:
        data = F.expand_dims(data, axis)

    return data, new_tuple_index


def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
shape = F.shape(data)
@@ -168,7 +192,7 @@ def _tensor_index_by_bool(data, bool_value):
"""Tensor getitem by a single bool value"""
if bool_value:
return F.expand_dims(data, 0)
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data))


def _tensor_index_by_integer(data, number):
@@ -207,8 +231,11 @@ def tensor_index_by_tuple(data, tuple_index):
op_name = const_utils.TENSOR_GETITEM
if len(tuple_index) == 1:
return data[tuple_index[0]]

tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims_with_bool(data, tuple_index, op_name)

indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
if contain_type == const_utils.ALL_TENSOR:
@@ -228,8 +255,8 @@ def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
def _tensor_getitem_by_tuple_slice(data, tuple_index):
    """Tensor getitem by a tuple of slice.

    Derives strided-slice parameters from the index tuple and applies
    StridedSlice; shrink_axis_mask drops axes that were indexed by integers.
    """
    data_shape = F.shape(data)
    # Compute the stride info exactly once (the block previously carried a
    # duplicated assignment performing the same call twice).
    begin_strides, end_strides, step_strides, shrink_axis_mask = const_utils.get_stride_info_from_tuple(
        data_shape, tuple_index)
    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


@@ -259,8 +286,8 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
tuple_index_len = len(tuple_index)
tensor_indexes, slice_indexes = [], []
indexes_types = hyper_map(F.typeof, tuple_index)
slice_positions, _, _, int_positions, _, \
tensor_positions, sequence_positions = const_utils.get_pos_of_indexes_types(indexes_types, op_name)
slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
const_utils.get_pos_of_indexes_types(indexes_types, op_name)
tuple_index_new = ()

for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
@@ -296,8 +323,8 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_len):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
tuple_index_new[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
@@ -321,6 +348,7 @@ def _generate_updates_from_tuple(data, index, value, op_type):
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = const_utils.check_value_elements(data_dtype, value_types)

if value_elements_type == const_utils.ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)


+ 77
- 72
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -73,6 +73,13 @@ def make_empty_slice():
return slice(None, None, None)


@constexpr
def make_tensor(data, data_type, data_shape=None):
    """Build a Tensor of dtype `data_type`.

    If `data_shape` is given, `data` is ignored and a zero-filled tensor of
    that shape is returned (used e.g. to produce an empty result for a bool
    index); otherwise `data` itself is converted.
    """
    if data_shape is not None:
        # `is not None` rather than truthiness, so an explicit scalar shape
        # () still yields a 0-d zeros tensor instead of falling through.
        return Tensor(np.zeros(data_shape), data_type)
    return Tensor(data, data_type)


@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the sensor and value."""
@@ -158,6 +165,36 @@ def check_indexes_types_valid(dtypes, target_type, op_name):
f"but got {dtype}.")


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
    """Separate the position information of tensor and slice and ellipsis from the mixed tensors index.

    Returns seven position lists (slice, ellipsis, none, int, bool, tensor,
    sequence), one index per element of `indexes_types`.

    Raises:
        IndexError: for an unsupported index element type, or when the tuple
            contains more than one ellipsis.
    """
    slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
        sequence_positions = [], [], [], [], [], [], []
    for i, index_type in enumerate(indexes_types):
        if isinstance(index_type, mstype.slice_type):
            slice_positions.append(i)
        elif isinstance(index_type, mstype.ellipsis_type):
            ellipsis_positions.append(i)
        elif isinstance(index_type, mstype.none_type):
            none_positions.append(i)
        elif isinstance(index_type, mstype.Int):
            int_positions.append(i)
        elif isinstance(index_type, mstype.bool_type):
            bool_positions.append(i)
        elif isinstance(index_type, mstype.tensor_type):
            tensor_positions.append(i)
        elif isinstance(index_type, (list, tuple)):
            sequence_positions.append(i)
        else:
            # Message now lists every index kind the branches above accept.
            raise IndexError(f"For '{op_name}', the index elements only support "
                             f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', 'None', 'bool', "
                             f"'list', 'tuple', but got {index_type}.")
    if len(ellipsis_positions) > 1:
        # Closing quote after {op_name} was missing in the original message.
        raise IndexError(f"For '{op_name}', an index can only have a single ellipsis('...')")

    return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
        tensor_positions, sequence_positions


def slice_expand(input_slices, shape):
"""
Converts slice to indices.
@@ -223,7 +260,7 @@ def ellipsis2slice(input_, shape):
return tuple(result)


@constexpr
@ constexpr
def slice2indices(input_slices, shape):
"""
Converts slice to indices.
@@ -248,7 +285,7 @@ def slice2indices(input_slices, shape):
return ravel


@constexpr
@ constexpr
def check_indices(indices_size, index):
"""Checks indices whether is empty."""
if indices_size < 1:
@@ -257,7 +294,7 @@ def check_indices(indices_size, index):
return indices_size


@constexpr
@ constexpr
def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched."""
if value_size < 1:
@@ -270,7 +307,7 @@ def check_indices_value_size(indices_size, value_size):
return value_size


@constexpr
@ constexpr
def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape)
@@ -280,7 +317,7 @@ def integer_to_indices(index, shape):
return Tensor(value, dtype=mstype.int32)


@constexpr
@ constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
if not indexs:
@@ -293,21 +330,14 @@ def tuple_element_is_int(indexs):
return False


@constexpr
def tuple_index_tensor_cnt(types, op_name):
    """Classify a tuple index by how many of its element types are tensors."""
    total = len(types)
    num_tensors = 0
    for element in types:
        if isinstance(element, mstype.tensor_type):
            num_tensors += 1
    if num_tensors == total:
        return ALL_TENSOR
    if num_tensors == 0:
        return NO_TENSOR
    return CONTAIN_TENSOR


@constexpr
@ constexpr
def tuple_index_int_cnt(types, op_name):
    """Classify a tuple index by how many of its element types are ints."""
    num_ints = sum(1 for element in types if isinstance(element, mstype.Int))
    if num_ints == len(types):
        return ALL_INT
    if num_ints == 0:
        return NO_INT
    return CONTAIN_INT


@constexpr
@ constexpr
def tuple_index_type_cnt(types, op_name):
"""count the tensor type of types which contains the tuple elements' type."""
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
@@ -319,7 +349,7 @@ def tuple_index_type_cnt(types, op_name):
return MIXED


@constexpr
@ constexpr
def check_value_elements(data_dtype, types):
"""Judges the type of all elements of the tuple."""
tensors_number = 0
@@ -344,8 +374,10 @@ def check_value_elements(data_dtype, types):
raise TypeError(
f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")

# TODO to del

@constexpr

@ constexpr
def get_index_tensor_dtype(dtype):
"""Check a tuple of tensor data type."""
if dtype == mstype.int32:
@@ -356,7 +388,8 @@ def get_index_tensor_dtype(dtype):
f"For '{TENSOR_SETITEM}', the index tensor data type '{dtype}' is not supported.")


@constexpr
# TODO to del
@ constexpr
def check_index_tensors_dtype(indexes_types, op_name):
"""Check a tuple of tensor data type."""
for index_type in indexes_types:
@@ -366,7 +399,8 @@ def check_index_tensors_dtype(indexes_types, op_name):
return True


@constexpr
# TODO to del
@ constexpr
def check_index_tensor_dtype(index_type, op_name):
"""Check a tensor data type."""
if index_type in (mstype.int32, mstype.int64):
@@ -375,7 +409,8 @@ def check_index_tensor_dtype(index_type, op_name):
f"but got {index_type}.")


@constexpr
# TODO to del
@ constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
"""Check tensors data type same."""
if value_dtype == data_dtype:
@@ -384,7 +419,7 @@ def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
f"is not consistent with assigned tensor data type {data_dtype}.")


@constexpr
@ constexpr
def generate_broadcast_shape(shapes, op_name):
"""Generate broadcast shape for a tuple of shape."""
if not shapes:
@@ -399,7 +434,7 @@ def generate_broadcast_shape(shapes, op_name):
return tuple(broadcast_shape)


@constexpr
@ constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y):
"""Check two shapes need broadcast."""
error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
@@ -416,14 +451,14 @@ def check_two_shapes_need_broadcast(shape_x, shape_y):
return True


@constexpr
@ constexpr
def compute_multiples(origin_shape, broadcast_shape):
    """Compute per-axis tile multiples that expand origin_shape to broadcast_shape."""
    gap = len(broadcast_shape) - len(origin_shape)
    # Leading axes absent from origin_shape are tiled to their full size.
    leading = broadcast_shape[:gap]
    trailing = tuple(b // o for b, o in zip(broadcast_shape[gap:], origin_shape))
    return leading + trailing


@constexpr
@ constexpr
def compute_new_shape(origin_shape, indexes_shapes_info):
"""Compute new shape between origin shape with final shape."""
new_shape = []
@@ -435,7 +470,7 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
return tuple(new_shape)


@constexpr
@ constexpr
def check_sequence_index_type(sequence_index, op_name):
"""check if the item's type of list_index is bool or int"""
if not all([isinstance(index, (int, bool)) for index in sequence_index]):
@@ -443,13 +478,13 @@ def check_sequence_index_type(sequence_index, op_name):
f"but got {type(index)} in array")


@constexpr
@ constexpr
def convert_int_to_slice(tuple_index):
    """Turn each int in tuple_index into the equivalent length-1 slice."""
    return tuple(slice(idx, idx + 1, 1) for idx in tuple_index)


@constexpr
@ constexpr
def check_and_transform_int_index(index, shape, op_name):
if index < -shape or index >= shape:
raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit "
@@ -459,7 +494,7 @@ def check_and_transform_int_index(index, shape, op_name):
return index


@constexpr
@ constexpr
def transform_sequence_index(sequence_index, shape, op_name):
"""transform list or tuple with integer and boolean to tuple with integer index"""
bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
@@ -477,7 +512,7 @@ def transform_sequence_index(sequence_index, shape, op_name):
return sub_tuple_index


@constexpr
@ constexpr
def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
"""Convert a slice to a tensor."""
shape = []
@@ -505,7 +540,7 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n
return tensor


@constexpr
@ constexpr
def check_shapes_same(value_shapes, op_name):
"""Check if the shapes in the tuple are consistent."""
for i, shape in enumerate(value_shapes):
@@ -515,7 +550,7 @@ def check_shapes_same(value_shapes, op_name):
return True


@constexpr
@ constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
"""Convert a scalar to a tensor."""
if op_type == SET_ITEM_BY_ONE_TENSOR:
@@ -528,7 +563,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
f" is not consistent with the assigned tensor data type {data_dtype}.")


@constexpr
@ constexpr
def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
"""Convert a tuple of scalar to a tensor."""
updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
@@ -540,7 +575,7 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value
return Tensor(np.tile(array, reps))


@constexpr
@ constexpr
def generate_updates_shape(data_shape, index_shape, op_type):
"""Generate updates shape for 'tensor setitem'."""
if op_type == SET_ITEM_BY_ONE_TENSOR:
@@ -550,7 +585,7 @@ def generate_updates_shape(data_shape, index_shape, op_type):
return updates_shape


@constexpr
@ constexpr
def check_tuple_index_len(data_rank, tuple_index_len, op_name):
"""Check if the number of index tensor exceeds the dimension of the operated tensor."""
if tuple_index_len <= data_rank:
@@ -559,7 +594,7 @@ def check_tuple_index_len(data_rank, tuple_index_len, op_name):
f"is greater than the dimension {data_rank} of the operated tensor.")


@constexpr
@ constexpr
def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes,
tensor_indexes_dtypes, slice_indexes, op_name):
"""
@@ -645,37 +680,7 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, tensor_positions, \
sequence_positions = [], [], [], [], [], [], []
for i, index_type in enumerate(indexes_types):
if isinstance(index_type, mstype.slice_type):
slice_positions.append(i)
elif isinstance(index_type, mstype.ellipsis_type):
ellipsis_positions.append(i)
elif isinstance(index_type, mstype.none_type):
none_positions.append(i)
elif isinstance(index_type, mstype.Int):
int_positions.append(i)
elif isinstance(index_type, mstype.bool_type):
bool_positions.append(i)
elif isinstance(index_type, mstype.tensor_type):
tensor_positions.append(i)
elif isinstance(index_type, (list, tuple)):
sequence_positions.append(i)
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
if len(ellipsis_positions) > 1:
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")

return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
tensor_positions, sequence_positions


@constexpr
@ constexpr
def scalar_in_sequence(x, y):
"""Determine whether the scalar in the sequence."""
if x is None:
@@ -689,14 +694,14 @@ def scalar_in_sequence(x, y):
return False


@constexpr
@ constexpr
def get_np_eps(input_dtype):
nptype = mstype.dtype_to_nptype(input_dtype)
eps = np.finfo(nptype).eps
return float(eps)


@constexpr
@ constexpr
def check_number_index_type(number):
"""Check if it is int or bool number"""
if isinstance(number, bool):
@@ -707,7 +712,7 @@ def check_number_index_type(number):
.format(number, type(number)))


@constexpr
@ constexpr
def get_stride_info_from_slice(data_shape, slice_index):
"""Get stride info from a python slice"""
begin, end, step = get_slice_stride(data_shape[0], slice_index)
@@ -721,7 +726,7 @@ def get_stride_info_from_slice(data_shape, slice_index):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides)


@constexpr
@ constexpr
def get_stride_info_from_integer(data_shape, number):
"""Get stride info from a integer"""
begin_strides = [number]
@@ -747,7 +752,7 @@ def get_slice_stride(dim_size, index_slice):
return start, stop, step


@constexpr
@ constexpr
def get_stride_info_from_tuple(data_shape, tuple_index):
"""Get stride info from a tuple"""
begin_strides, end_strides, step_strides = [], [], []
@@ -787,14 +792,14 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis


@constexpr
@ constexpr
def mstype_eq(x, y):
    """Return True when x equals y, otherwise False."""
    # Conditional expression forces a strict bool result, as the original did.
    return True if x == y else False


@constexpr
@ constexpr
def scalar_to_tensor(x):
    """Convert a scalar to a tensor"""
    # Wraps the scalar directly; no dtype is passed, so the resulting dtype
    # is whatever Tensor's constructor chooses for this Python scalar.
    return Tensor(x)

+ 24
- 4
tests/ut/python/ops/test_tensor_fancy_index.py View File

@@ -14,6 +14,7 @@
# ============================================================================
""" test_tensor_slice """
import numpy as np
import pytest

from mindspore import Tensor
from mindspore import context
@@ -48,7 +49,7 @@ def test_tensor_fancy_index_boolean_list():
net(input_me)


def test_tensor_fancy_integer_boolean_list_graph():
def test_tensor_fancy_index_integer_boolean_list_graph():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = [1, 2, True, False]
net = NetWorkFancyIndex(index)
@@ -57,7 +58,7 @@ def test_tensor_fancy_integer_boolean_list_graph():
net(input_me)


def test_tensor_fancy_integer_list_mixed():
def test_tensor_fancy_index_integer_list_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, [2, 1, 3], slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)
@@ -66,7 +67,7 @@ def test_tensor_fancy_integer_list_mixed():
net(input_me)


def test_tensor_fancy_integer_tuple_mixed():
def test_tensor_fancy_index_integer_tuple_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, (2, 1, 3), slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)
@@ -75,10 +76,29 @@ def test_tensor_fancy_integer_tuple_mixed():
net(input_me)


def test_tensor_fancy_integer_list_tuple_mixed():
def test_tensor_fancy_index_integer_list_tuple_mixed():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
index = (1, [2, 1, 3], (3, 2, 1), slice(1, 3, 1), ..., 4)
net = NetWorkFancyIndex(index)
input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
input_me = Tensor(input_np, dtype=mstype.float32)
net(input_me)


def test_tensor_fancy_index_integer_list_tuple_bool_mixed():
    """Getitem with a mixed tuple index containing True values compiles and runs."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    # Two True entries mixed with int, list, tuple, slice and ellipsis indices.
    index = (1, [2, 1, 3], True, (3, 2, 1), slice(1, 3, 1), ..., True, 4)
    net = NetWorkFancyIndex(index)
    input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
    input_me = Tensor(input_np, dtype=mstype.float32)
    # Smoke test only: no result-value assertion, just successful execution.
    net(input_me)


def test_tensor_fancy_index_integer_list_tuple_bool_mixed_error():
    """A False in the index tuple is rejected: getitem raises IndexError."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    # Identical to the positive case except one True is replaced with False.
    index = (1, [2, 1, 3], True, (3, 2, 1), slice(1, 3, 1), ..., False, 4)
    net = NetWorkFancyIndex(index)
    input_np = np.arange(3*4*5*6*7*8).reshape(3, 4, 5, 6, 7, 8)
    input_me = Tensor(input_np, dtype=mstype.float32)
    with pytest.raises(IndexError):
        net(input_me)

Loading…
Cancel
Save