Support single bracket setitem

pull/14166/head
yanglf1121 4 years ago
commit 15820776fc
5 changed files with 994 additions and 164 deletions
  1. mindspore/ops/composite/multitype_ops/_compile_utils.py (+258 -113)
  2. mindspore/ops/composite/multitype_ops/_constexpr_utils.py (+149 -35)
  3. mindspore/ops/composite/multitype_ops/setitem_impl.py (+369 -13)
  4. tests/st/pynative/test_tensor_index.py (+3 -3)
  5. tests/st/pynative/test_tensor_setitem.py (+215 -0)
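
A minimal usage sketch of what this commit enables, mirroring cases from the new tests (assumes a MindSpore build that includes this patch; shapes are illustrative):

import numpy as onp
from mindspore import Tensor

x = Tensor(onp.ones((2, 3, 4), onp.float32))
x[[0, 1], [1, 2], [1, 3]] = [3, 4]                      # list indices, sequence value
x[..., 1] = ([1, 2, 3], [4, 5, 6])                      # ellipsis index, tuple value
x[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11))   # slice index, nested sequence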

mindspore/ops/composite/multitype_ops/_compile_utils.py (+258 -113)

@@ -23,7 +23,7 @@ from ....common import dtype as mstype
from ....common._register_for_tensor import tensor_operator_registry

hyper_map = base.HyperMap()
pack = P.Stack(axis=-1)
stack = P.Stack(axis=-1)


def _tensor_getitem(self, index):
@@ -36,44 +36,35 @@ def _tensor_getitem(self, index):
return tensor_index_by_tuple(self, (index,))
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
f"list and tuple ,but got {index} with type {type(index)}.")
tensor_operator_registry.register("__getitem__", _tensor_getitem)

def _tensor_setitem(self, index, value):
"""Handle tensor getitem"""
"""Handle tensor setitem"""
if not isinstance(value, (int, float, bool, list, tuple, Tensor)):
raise ValueError(f"only support numbers, Tensor, tuple, list as value,"
f"but got {value} with type {type(value)}.")

if isinstance(index, list):
index = format_list_indices(index, self.shape[0])
if isinstance(index, Tensor):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tensor_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tensor_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tensor_with_tuple(self, index, value)
return tensor_setitem_by_tensor(self, index, value)
if isinstance(index, tuple):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tuple_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tuple_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tuple_with_tuple(self, index, value)
if tuple_indices_have_false(index):
return self
index = format_tuple_indices(index)
return tensor_setitem_by_tuple(self, index, value)
if isinstance(index, bool):
return tensor_setitem_by_bool(self, index, value)
if isinstance(index, int):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_number_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_number_with_tensor(self, index, value)
return tensor_setitem_by_number(self, index, value)
if isinstance(index, slice):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_slice_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_slice_with_tensor(self, index, value)
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
return tensor_setitem_by_slice(self, index, value)
if index is ...:
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_ellipsis_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_ellipsis_with_tensor(self, index, value)
return tensor_setitem_by_ellipsis(self, index, value)

raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\
and tensor with int32, got {} with type{}".format(index, type(index)))
tensor_operator_registry.register("__setitem__", _tensor_setitem)

def _broadcast(broadcast_shape, x):
"""Broadcast tensor to the required shape."""
@@ -103,7 +94,8 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name):
ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) +
len(tensor_positions) + len(sequence_positions))
ellipsis_cnt = len(ellipsis_positions)
if (ellipsis_cnt == 0 and ellipsis_occupy_dims < 0) or (ellipsis_cnt > 0 and ellipsis_occupy_dims < 1):
# pylint: disable=chained-comparison
if ellipsis_occupy_dims < 0 and ellipsis_cnt >= 0:
const_utils.raise_index_error("For the 'getitem Operator', the data_shape should be no less than the "
"tuple index dims")

@@ -155,14 +147,6 @@ def tensor_index_by_number(data, number_index):
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")


# TODO wait to remove after setitem by Yang Linfeng
def _tensor_index_by_bool(data, bool_index):
"""Tensor getitem by a single bool value"""
if bool_index:
return F.expand_dims(data, 0)
return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data))


def _tensor_index_by_integer(data, int_index):
"""Tensor getitem by a single integer number"""
data_shape = F.shape(data)
@@ -218,6 +202,31 @@ def tensor_index_by_tuple(data, tuple_index):
return _tensor_getitem_by_tuple(data, tuple_index, op_name)


def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
"""Tensor getitem by a tuple of tensor."""
data_shape = F.shape(data)
tuple_index_len = len(tuple_index)

indexes_types = hyper_map(F.dtype, tuple_index)
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
if 0 in broadcast_shape:
res_shape = broadcast_shape
if tuple_index_len < len(data_shape):
res_shape += data_shape[tuple_index_len:]
res = const_utils.make_tensor([], data.dtype, res_shape)
return res

broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
new_broadcast_tensors = ()
for tensor in broadcast_tensors:
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
indices = stack(new_broadcast_tensors)
result = F.gather_nd(data, indices)
return result


def _tensor_getitem_by_tuple_slice(data, tuple_index):
"""Tensor getitem by a tuple of slice"""
data_shape = F.shape(data)
@@ -284,8 +293,9 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
final_index_tensors.append(slice_index_tensor)
slice_cnt += 1

indices = pack(final_index_tensors)
return F.gather_nd(data, indices)
indices = stack(final_index_tensors)
result = F.gather_nd(data, indices)
return result


def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
@@ -299,7 +309,7 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
new_broadcast_tensors = ()
for tensor in broadcast_tensors:
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
indices = pack(new_broadcast_tensors)
indices = stack(new_broadcast_tensors)
return indices


@@ -332,6 +342,11 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
elif i in slice_positions:
start, stop, _ = const_utils.slice_to_tuple(index)
start = const_utils.normalize_start(start, dim_size)
stop = const_utils.normalize_stop(stop, dim_size)
if start >= stop:
return None
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
slice_shapes += (len(slice_ele_list_index),)
tuple_index_new += (slice_ele_list_index,)
@@ -354,7 +369,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name):
final_index_tensors.append(slice_index_tensor)
slice_cnt += 1

indices = pack(final_index_tensors)
indices = stack(final_index_tensors)
return indices


@@ -366,44 +381,76 @@ def _generate_updates_from_scalar(data, indices, value, op_type):
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type)


def _generate_updates_from_tuple(data, index, value, op_type):
"""Generate an updates tensor from a tuple."""
def _generate_updates_from_sequence(data, index, value, op_type):
"""Generate an updates tensor from a tuple, can only handle 1-D tensor/non-tensor mixtures."""
value_types = hyper_map(F.typeof, value)
data_dtype = F.dtype(data)
value_elements_type = const_utils.check_value_elements(data_dtype, value_types)
value_elements_type = const_utils.check_value_elements(value_types)

if value_elements_type == const_utils.ALL_TENSOR:
value_shapes = hyper_map(F.shape, value)
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM)
if shapes_same:
value = F.stack(value)
return _generate_updates_from_tensor(data, index, value, op_type)

data_shape = F.shape(data)
index_shape = F.shape(index)
return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type)
value = F.stack(value).astype(data.dtype)
elif value_elements_type == const_utils.NO_TENSOR:
value = const_utils.make_tensor(value, data.dtype)
else:
new_value = ()
for ele in value:
ele = ele if isinstance(ele, Tensor) else const_utils.make_tensor(ele)
new_value += (ele,)
value = F.stack(new_value).astype(data.dtype)
if op_type == const_utils.SET_ITEM_BY_NON_TENSOR:
return value
return _generate_updates_from_tensor(data, index, value, op_type)
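# Illustrative normalization performed above (a sketch, not a spec):
#   (1, 2, 3)                        -> make_tensor((1, 2, 3), data.dtype)
#   (Tensor([1, 2]), Tensor([3, 4])) -> stacked to a (2, 2) tensor in data.dtype
#   (Tensor(1), 2, True)             -> each element wrapped as a tensor, then stacked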


def _generate_updates_from_tensor(data, index, value, op_type):
"""Generate an updates tensor from a tensor."""
data_shape = F.shape(data)
index_shape = F.shape(index)
value_shape = F.shape(value)
data_dtype = F.dtype(data)
value_dtype = F.dtype(value)
updates_shape = value_shape
check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM)
if check_dtype_same:
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type)
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape)
value = value.astype(data.dtype)
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape)
if need_broadcast:
return _broadcast(updates_shape, value)
return value


tensor_operator_registry.register("__getitem__", _tensor_getitem)
# Tensor getitem implementations are above this line, setitem implementations below.

tensor_operator_registry.register("__setitem__", _tensor_setitem)
def tensor_setitem_by_tensor(self, index, value):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tensor_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tensor_with_tensor(self, index, value)
return tensor_setitem_by_tensor_with_sequence(self, index, value)


def tensor_setitem_by_tuple(self, index, value):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tuple_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tuple_with_tensor(self, index, value)
return tensor_setitem_by_tuple_with_sequence(self, index, value)


def tensor_setitem_by_number(self, index, value):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_number_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_number_with_tensor(self, index, value)
return tensor_setitem_by_number_with_sequence(self, index, value)


def tensor_setitem_by_slice(self, index, value):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_slice_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_slice_with_tensor(self, index, value)
return tensor_setitem_by_slice_with_sequence(self, index, value)


def tensor_setitem_by_ellipsis(self, index, value):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_ellipsis_with_number(self, value)
if isinstance(value, Tensor):
return tensor_setitem_by_ellipsis_with_tensor(self, value)
return tensor_setitem_by_ellipsis_with_sequence(self, value)


def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
@@ -469,17 +516,16 @@ def tensor_setitem_by_tensor_with_number(data, index, value):
return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int")


def tensor_setitem_by_tensor_with_tuple(data, index, value):
def tensor_setitem_by_tensor_with_sequence(data, index, value):
"""Assigns the tensor by tensor with tuple value."""
index_dtype = F.dtype(index)
const_utils.check_type_valid(index_dtype, (mstype.int32, mstype.int64), const_utils.TENSOR_SETITEM)
result = _tensor_setitem_by_tensor_with_tuple(data, index, value)
return result
return _tensor_setitem_by_tensor_with_sequence(data, index, value)


def _tensor_indices_number(data, data_shape, index, indices, value):
"""Assigns a scalar value to the tensor."""
data_size = F.size(data)
data_size = F.shape_mul(data.shape)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
@@ -493,9 +539,9 @@ def _tensor_indices_number(data, data_shape, index, indices, value):
return F.select(condition, u, data)


def _tensor_setitem_by_tensor_with_tuple(data, index, value):
def _tensor_setitem_by_tensor_with_sequence(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = _generate_updates_from_tuple(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)

@@ -507,6 +553,8 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value):
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
if indices is False:
return data
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
@@ -516,6 +564,8 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value):

def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
"""Assigns the tensor by tuple with number value."""
tuple_index = ignore_dim_expand(tuple_index)

if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
@@ -533,13 +583,15 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
if int_cnt == const_utils.ALL_INT:
tuple_index = const_utils.convert_int_to_slice(tuple_index)
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
if indices is None:
return data
updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)


def _tensor_indices_tensor(data, data_shape, index, indices, value):
"""Assigns a tensor value to the tensor."""
data_size = F.size(data)
data_size = F.shape_mul(data.shape)
data_dtype = F.dtype(data)
indices_size = F.size(indices)
indices_size = const_utils.check_indices(indices_size, index)
@@ -548,7 +600,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
condition = F.reshape(condition_1d, data_shape)
condition = F.cast(condition, mstype.bool_)
value_fill = None
value_size = F.size(value)
value_size = value.size

value_size = const_utils.check_indices_value_size(indices_size, value_size)
if value_size == 1:
@@ -559,7 +611,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value):
value_fill = F.reshape(value, (indices_size,))
value_1d = F.scatter_nd(indices, value_fill, (data_size,))
u = F.reshape(value_1d, data_shape)
return F.select(condition, u, data)
return F.select(condition, u.astype(data_dtype), data)


def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
@@ -569,6 +621,8 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
if check_result:
data_shape = F.shape(data)
indices = const_utils.slice2indices(input_slice, data_shape)
if indices is False:
return data
is_tuple_int = const_utils.tuple_element_is_int(input_slice)
if is_tuple_int:
indices = const_utils.integer_to_indices(input_slice, data_shape)
@@ -576,8 +630,18 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
return result


def tensor_setitem_by_slice_with_sequence(data, input_slice, value):
"""Assigns a list/tuple value to the tensor by slice."""
value = _generate_updates_from_sequence(data, input_slice, value, const_utils.SET_ITEM_BY_NON_TENSOR)
return tensor_setitem_by_slice_with_tensor(data, input_slice, value)


def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
"""Assigns the tensor by tuple with tensor value."""
value_shape = remove_ignored_dim(tuple_index, F.shape(value), F.rank(data))
value = F.reshape(value, value_shape)
tuple_index = ignore_dim_expand(tuple_index)

if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
@@ -600,31 +664,15 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
new_shape += value.shape
value = F.reshape(value, new_shape)
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
if indices is None:
return data
updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)


def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
"""Assigns the tensor by tuple with tuple of value."""
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims(data, tuple_index)

indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)

if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
tuple_index = const_utils.convert_int_to_slice(tuple_index)
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tuple(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
value = _generate_updates_from_sequence(data, tuple_index, value, const_utils.SET_ITEM_BY_NON_TENSOR)
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)


def tensor_setitem_by_number_with_number(data, index, value):
@@ -634,6 +682,12 @@ def tensor_setitem_by_number_with_number(data, index, value):
return _tensor_indices_number(data, data_shape, index, indices, value)


def tensor_setitem_by_number_with_sequence(data, index, value):
"""Assigns a list/tuple value to the tensor by slice."""
value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR)
return tensor_setitem_by_number_with_tensor(data, index, value)


def tensor_setitem_by_number_with_tensor(data, index, value):
"""Assigns the tensor by number with tensor value."""
data_shape = F.shape(data)
@@ -641,31 +695,46 @@ def tensor_setitem_by_number_with_tensor(data, index, value):
return _tensor_indices_tensor(data, data_shape, index, indices, value)


def tensor_setitem_by_ellipsis_with_number(data, index, value):
def tensor_setitem_by_ellipsis_with_number(data, value):
"""Assigns the tensor by ellipsis with number value."""
data_shape = F.shape(data)
data_dtype = F.dtype(data)
return F.fill(data_dtype, data_shape, value)


def tensor_setitem_by_ellipsis_with_tensor(data, index, value):
def tensor_setitem_by_ellipsis_with_tensor(data, value):
"""Assigns the tensor by ellipsis with tensor value."""
result = None
data_shape = F.shape(data)
data_dtype = F.dtype(data)
data_size = F.size(data)
value = value.astype(data_dtype)
value_shape = F.shape(value)
value_size = F.size(value)
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size)
if check_result:
if data_size == value_size:
result = F.reshape(value, data_shape)
result = F.cast(result, data_dtype)
elif value_size == 1:
param1 = F.fill(data_dtype, data_shape, 1)
param2 = F.cast(value, data_dtype)
result = F.tensor_mul(param1, param2)
return result
source_shape = const_utils.get_source_shape(data_shape, value_shape)
value = F.reshape(value, source_shape)
value = _broadcast(data_shape, value)
data = F.cast(value, data_dtype)
return data


def tensor_setitem_by_ellipsis_with_sequence(data, value):
"""Assigns a list/tuple value to the tensor by ellipsis."""
value = _generate_updates_from_sequence(data, None, value, const_utils.SET_ITEM_BY_NON_TENSOR)
return tensor_setitem_by_ellipsis_with_tensor(data, value)


def tensor_setitem_by_bool(data, index, value):
"""Assigns a value to the tensor by boolean."""
data_shape = F.shape(data)
if not index:
data_shape = (0,) + data_shape
if not isinstance(value, Tensor):
value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR)
value_shape = F.shape(value)
source_shape = const_utils.get_source_shape(data_shape, value_shape)
if index:
value = F.reshape(value, source_shape)
value = _broadcast(data_shape, value)
data = value
return data
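# Behavior sketch: x[True] = v broadcasts v to the full shape of x and replaces
# its contents; x[False] = v is a no-op (the (0,)-prefixed shape selects nothing).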


def tensor_in_sequence(x, y):
@@ -675,3 +744,79 @@ def tensor_in_sequence(x, y):
if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype:
result = F.logical_or(F.equal(x, i).all(), result)
return result


def format_list_indices(list_indices, length):
"""Convert list indices to tensor or tuple indices based on its contents."""
indices_types = hyper_map(F.typeof, list_indices)
# If every element in list is bool, it's treated as a 1-D bool tensor.
# If every element in list is int (not all bool), it's treated as an int tensor.
if const_utils.judge_indexes_types(indices_types, mstype.int_type+(mstype.bool_,)):
list_indices = const_utils.transform_sequence_index(list_indices, length, const_utils.TENSOR_SETITEM)
return const_utils.make_tensor(list_indices)
# If list contains other types (.../list/tuple/None), it's treated as a tuple.
return const_utils.deep_tuple(list_indices)
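# Illustrative mapping (per the comments above):
#   [1, 2]                 -> int index tensor
#   [True, False, True]    -> 1-D bool index tensor
#   [0, None, slice(1, 2)] -> deep tuple, handled by the tuple-index path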


def format_tuple_indices(tuple_indices):
"""
Format tuple indices by unpacking high-dimension tuples and removing expand
dimension signs (Bool and None).
"""
res = ()
for i in tuple_indices:
if isinstance(i, (list, tuple)):
res += (const_utils.unpack(i),)
else:
res += (i,)
return res


def tuple_indices_have_false(tuple_indices):
"""Returns True if tuple_indices contains False."""
for i in tuple_indices:
if i is False:
return True
return False


def ignore_dim_expand(idx):
"""Filters flags for dimension expansion from idx."""
res = ()
for i in idx:
if i is not True and i is not None:
res += (i,)
if not res:
res = (True,)
return res
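# e.g. ignore_dim_expand((0, True, None, slice(1, 3))) -> (0, slice(1, 3));
# an index made up solely of expansion flags collapses to (True,).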


def remove_ignored_dim(idx, value_shape, data_rank):
"""Removes dimensions in value that correspond to dimension expansion flags in index."""
has_ellipsis = False
has_true = False
cnt_trailing_expanded = 0
cnt_not_dim_expand = 0
for i in idx:
if i is not True and i is not None:
cnt_not_dim_expand += 1
if const_utils.is_ellipsis(i):
has_ellipsis = True
elif has_ellipsis:
if i is None:
cnt_trailing_expanded += 1
elif i is True and not has_true:
has_true = True
if has_true and cnt_not_dim_expand + 1 < data_rank:
cnt_trailing_expanded += 1

if cnt_trailing_expanded == 0:
return value_shape
value_expanded_pos = len(value_shape) - cnt_trailing_expanded
value_expanded_not_unit = False
for i in value_shape[value_expanded_pos:]:
if i != 1:
value_expanded_not_unit = True
if value_expanded_pos < 0 or value_expanded_not_unit:
const_utils.raise_value_error('shape mismatch')
return value_shape[:value_expanded_pos]

mindspore/ops/composite/multitype_ops/_constexpr_utils.py (+149 -35)

@@ -43,6 +43,7 @@ TENSOR_GETITEM = "tensor getitem"

SET_ITEM_BY_ONE_TENSOR = 0
SET_ITEM_BY_TUPLE_OF_TENSOR = 1
SET_ITEM_BY_NON_TENSOR = 2


@constexpr
@@ -74,10 +75,85 @@ def make_empty_slice():


@constexpr
def make_tensor(data, data_type=mstype.int64, data_shape=None):
def _deep_list(array_like):
"""convert nested tuple/list mixtures to pure nested list"""
if isinstance(array_like, (list, tuple)):
return list(map(_deep_list, array_like))
return array_like


@constexpr
def deep_tuple(array_like):
"""convert nested tuple/list mixtures to pure nested tuple"""
if isinstance(array_like, (list, tuple)):
return tuple(map(deep_tuple, array_like))
return array_like
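# e.g. _deep_list(((1, 2), [3, (4,)])) -> [[1, 2], [3, [4]]]
#      deep_tuple([[1, 2], (3, [4])])  -> ((1, 2), (3, (4,)))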


def _deep_tensor_to_nparray(array_like):
"""
Convert a nested list of tensors to a nested list of numpy arrays.

Args:
array_like(list(tensor)): In any format of nested lists that may contain
tensors.

Returns:
array_like(list(np_array)): Formatted array that can be directly processed
by numpy.array(), with all tensor elements converted to numpy_array.
"""
# Recursively check whether each element is a tensor; if so, convert it to a
# numpy array in place
if isinstance(array_like, Tensor):
return array_like.asnumpy()

if isinstance(array_like, list):
for idx, value in enumerate(array_like):
array_like[idx] = _deep_tensor_to_nparray(value)

return array_like


@constexpr
def make_tensor(a, dtype=mstype.int32, data_shape=None):
"""
Converts the input to tensor.

This function converts tensors from an array-like object.

Args:
a (Union[int, float, bool, list, tuple]): Input data, in any form that can
be converted to a `Tensor`.
dtype (:class:`mindspore.dtype`): Designated tensor dtype.

Returns:
Tensor, generated tensor with the specified dtype.

Raises:
TypeError: If input arguments have types not specified above.
ValueError: If input `a` has different sizes at different dimensions.
"""

if data_shape:
return Tensor(np.zeros(data_shape), data_type)
return Tensor(data, data_type)
return Tensor(np.zeros(data_shape), dtype)

if not isinstance(a, (list, tuple, int, float, bool)):
raise TypeError("input data must be `int`, `float`, `bool`, `list` or `tuple`")

if isinstance(a, (list, tuple)):
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
a = np.asarray(a)
if a.dtype is np.dtype('object'):
raise ValueError('Input array must have the same size across all dimensions.')

if isinstance(a, np.ndarray):
if a.dtype is np.dtype('object'):
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")

return Tensor(a, dtype)
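# e.g. make_tensor([[1, 2], [3, 4]])           -> a 2x2 int32 Tensor
#      make_tensor([], mstype.float32, (2, 3)) -> a zero-filled 2x3 float32 Tensor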


@constexpr
@@ -88,12 +164,20 @@ def judge_data_rank(data_rank, min_data_rank=0, max_data_rank=8):


@constexpr
def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size):
"""Checks the shape and size of the sensor and value."""
if data_shape == value_shape or data_size == value_size or value_size == 1:
return True
raise ValueError("The value(shape={}), can not assign to tensor(shape={}).".format(
value_shape, data_shape))
def get_source_shape(data_shape, value_shape):
"""Returns the shape of value that will be used to broadcast against data."""
cannot_broadcast = False
source_shape = value_shape
for i, j in zip(reversed(data_shape), reversed(value_shape)):
if j not in (1, i):
cannot_broadcast = True
for i in range(len(value_shape) - len(data_shape)):
source_shape = data_shape
if value_shape[i] != 1:
cannot_broadcast = True
if cannot_broadcast:
raise ValueError(f'could not broadcast input array from shape {value_shape} to {data_shape}')
return source_shape
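# e.g. get_source_shape((2, 3, 4), (3, 1))    -> (3, 1)  (value broadcasts as-is)
#      get_source_shape((2, 3), (1, 1, 2, 3)) -> (2, 3)  (extra leading 1s dropped)
#      get_source_shape((2, 3), (4, 3))       -> raises ValueError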


@constexpr
@@ -288,8 +372,10 @@ def slice2indices(input_slices, shape):
begin, end, strides = slice_expand(input_slices, shape)
np_r = []
for i, element in enumerate(shape):
s = begin[i] if (begin[i] >= 0) else (element + begin[i])
e = end[i] if (end[i] >= 0) else (element + end[i])
s = normalize_start(begin[i], element)
e = normalize_stop(end[i], element)
if s >= e:
return False
np_r.append(np.r_[s:e:strides[i]])
# Reference: np.ravel_multi_index((np.ix_(np.r_[1:3:1], np.r_[0:4:1], np.r_[4:0:-1])), a.shape)
np_ix = np.ix_(*np_r)
@@ -364,29 +450,17 @@ def tuple_index_type_cnt(types, op_name):


@constexpr
def check_value_elements(data_dtype, types):
def check_value_elements(types):
"""Judges the type of all elements of the tuple."""
tensors_number = 0
scalars_number = 0
for i, ele in enumerate(types):
tensor_number = 0
for ele in types:
if isinstance(ele, mstype.tensor_type):
ele_dtype = ele.element_type()
if data_dtype == ele_dtype:
tensors_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the data type of {i}th tensor '{ele_dtype}' "
f"in value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
elif mstype.dtype_to_pytype(ele) == mstype.dtype_to_pytype(data_dtype):
scalars_number += 1
else:
raise TypeError(f"For '{TENSOR_SETITEM}', the {i}th element type '{ele}' in "
f"value tuple is not consistent with assigned tensor data type '{data_dtype}'.")
if tensors_number == len(types):
tensor_number += 1
if tensor_number == 0:
return NO_TENSOR
if tensor_number == len(types):
return ALL_TENSOR
if scalars_number == len(types):
return ALL_SCALAR
raise TypeError(
f"For '{TENSOR_SETITEM}', the value does not support scalar and tensor mixing, but got {types}.")
return CONTAIN_TENSOR


@constexpr
@@ -528,10 +602,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
updates_shape = indices_shape + data_shape[1:]
else:
updates_shape = indices_shape[:-1] + data_shape[indices_shape[-1]:]
if isinstance(value, mstype.dtype_to_pytype(data_dtype)):
return Tensor(np.full(updates_shape, value), dtype=data_dtype)
raise TypeError(f"For '{TENSOR_SETITEM}', the value type '{value.__class__.__name__}'"
f" is not consistent with the assigned tensor data type {data_dtype}.")
return Tensor(np.full(updates_shape, value), dtype=data_dtype)


@constexpr
@@ -716,3 +787,46 @@ def mstype_eq(x, y):
def scalar_to_tensor(x):
"""Convert a scalar to a tensor"""
return Tensor(x)


@constexpr
def unpack(x):
if isinstance(x, (tuple, list)) and len(x) == 1:
return unpack(x[0])
return x


@constexpr
def slice_to_tuple(s):
return (s.start, s.stop, s.step)


@constexpr
def normalize_start(start, dim_size):
"""
Normalize `start` to the range [0, `dim_size`]: `None` maps to 0, in-range
negative values are wrapped, and out-of-range values are clipped.
"""
if start is None:
return 0
if start < 0:
return 0 if start < -dim_size else start % dim_size
return start if start < dim_size else dim_size


@constexpr
def normalize_stop(stop, dim_size):
"""
Normalize `stop` to the range [0, `dim_size`]: `None` maps to `dim_size`,
in-range negative values are wrapped, and out-of-range values are clipped.
"""
if stop is None:
return dim_size
if stop < 0:
return 0 if stop < -dim_size else stop % dim_size
return stop if stop < dim_size else dim_size
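# e.g. with dim_size=5:
#   normalize_start(-2, 5) -> 3, normalize_start(-9, 5) -> 0
#   normalize_stop(None, 5) -> 5, normalize_stop(7, 5) -> 5, normalize_stop(-1, 5) -> 4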


@constexpr
def is_ellipsis(x):
return x is Ellipsis

mindspore/ops/composite/multitype_ops/setitem_impl.py (+369 -13)

@@ -18,6 +18,7 @@
from . import _compile_utils as compile_utils
from ... import functional as F
from ...composite import base
from ....common import Tensor

setitem = base.MultitypeFuncGraph('setitem')

@@ -213,6 +214,9 @@ def _tensor_setitem_by_tuple_with_number(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
if compile_utils.tuple_indices_have_false(tuple_index):
return data
tuple_index = compile_utils.format_tuple_indices(tuple_index)
return compile_utils.tensor_setitem_by_tuple_with_number(data, tuple_index, value)


@@ -234,6 +238,9 @@ def _tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
if compile_utils.tuple_indices_have_false(tuple_index):
return data
tuple_index = compile_utils.format_tuple_indices(tuple_index)
return compile_utils.tensor_setitem_by_tuple_with_tensor(data, tuple_index, value)


@@ -246,21 +253,49 @@ def _tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
Syntax support: A[B, C, D] = U.
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
2) B, C and D could be broadcast.
3) U is a Tensor.
3) U is a Tuple.

Inputs:
data (Tensor): Assigned tensor.
index (Tuple): A tuple of tensor, these tensor could be broadcast.
value (Tensor): Assignment tensor, should has the same data type as 'data'.
value (Tuple): Assignment tuple.

Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_tuple_with_tuple(data, tuple_index, value)
if compile_utils.tuple_indices_have_false(tuple_index):
return data
tuple_index = compile_utils.format_tuple_indices(tuple_index)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value)


@setitem.register("Tensor", "Tuple", "List")
def _tensor_setitem_by_tuple_with_list(data, tuple_index, value):
"""
Tensor assignment.

Note:
Syntax support: A[B, C, D] = U.
Restraint condition: 1) A is a Tensor, and B, C, D are index Tensors.
2) B, C and D could be broadcast.
3) U is a List.

Inputs:
data (Tensor): Assigned tensor.
index (Tuple): A tuple of tensor, these tensor could be broadcast.
value (List): Assignment list.

Outputs:
Tensor, element type and shape is same as data.
"""
if compile_utils.tuple_indices_have_false(tuple_index):
return data
tuple_index = compile_utils.format_tuple_indices(tuple_index)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, tuple_index, value)


@setitem.register("Tensor", "Tensor", "Tuple")
def _tensor_setitem_by_tensor_v2(data, index, value):
def _tensor_setitem_by_tensor_with_tuple(data, index, value):
"""
Tensor assignment.

@@ -272,11 +307,27 @@ def _tensor_setitem_by_tensor_v2(data, index, value):
Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_tensor_with_tuple(data, index, value)
return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)


@setitem.register("Tensor", "Tensor", "List")
def _tensor_setitem_by_tensor_with_list(data, index, value):
"""
Tensor assignment.

Inputs:
data (Tensor): Assigned tensor.
index (Tensor): Tensor of bool type.
value (List): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)


@setitem.register("Tensor", "Slice", "Tensor")
def _tensor_setitem_with_slice_v3(data, input_slice, value):
def _tensor_setitem_by_slice_with_tensor(data, input_slice, value):
"""
Tensor assignment.

@@ -298,7 +349,7 @@ def _tensor_setitem_with_slice_v3(data, input_slice, value):


@setitem.register("Tensor", "Slice", "Number")
def _tensor_setitem_with_slice_v1(data, input_slice, value):
def _tensor_setitem_by_slice_with_number(data, input_slice, value):
"""
Tensor assignment.

@@ -319,21 +370,326 @@ def _tensor_setitem_with_slice_v1(data, input_slice, value):
return compile_utils.tensor_setitem_by_slice_with_number(data, input_slice, value)


@setitem.register("Tensor", "Slice", "List")
def _tensor_setitem_by_slice_with_list(data, input_slice, value):
"""
Tensor assignment.

Note:
Syntax support: A[Slice] = u
Restraint condition: A is a Tensor.
Slice like "1:3"
u is a list

Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice): slice expression.
value (List): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_slice_with_sequence(data, input_slice, value)


@setitem.register("Tensor", "Slice", "Tuple")
def _tensor_setitem_by_slice_with_tuple(data, input_slice, value):
"""
Tensor assignment.

Note:
Syntax support: A[Slice] = u
Restraint condition: A is a Tensor.
Slice like "1:3"
u is a tuple

Inputs:
data (Tensor): Assigned tensor.
input_slice (Slice): slice expression.
value (Tuple): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_slice_with_sequence(data, input_slice, value)



@setitem.register("Tensor", "Number", "Number")
def _tensor_setitem_with_int_v1(data, index, value):
def _tensor_setitem_by_number_with_number(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[Number] = u
Restraint condition: A is a Tensor.
u is a Number.

Inputs:
data (Tensor): Assigned tensor.
index (Number): An integer index.
value (Number): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
if isinstance(index, bool):
return compile_utils.tensor_setitem_by_bool(data, index, value)
return compile_utils.tensor_setitem_by_number_with_number(data, index, value)


@setitem.register("Tensor", "Number", "Tensor")
def _tensor_setitem_with_int_v2(data, index, value):
def _tensor_setitem_by_number_with_tensor(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[Number] = u
Restraint condition: A is a Tensor.
u is a Tensor.

Inputs:
data (Tensor): Assigned tensor.
index (Number): An integer index.
value (Tensor): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
if isinstance(index, bool):
return compile_utils.tensor_setitem_by_bool(data, index, value)
return compile_utils.tensor_setitem_by_number_with_tensor(data, index, value)


@setitem.register("Tensor", "Number", "Tuple")
def _tensor_setitem_by_number_with_tuple(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[Number] = u
Restraint condition: A is a Tensor.
u is a Tuple, with all elements equal in length.

Inputs:
data (Tensor): Assigned tensor.
index (Number): An integer index.
value (Tuple): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
if isinstance(index, bool):
return compile_utils.tensor_setitem_by_bool(data, index, value)
return compile_utils.tensor_setitem_by_number_with_sequence(data, index, value)


@setitem.register("Tensor", "Number", "List")
def _tensor_setitem_by_number_with_list(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[Number] = u
Restraint condition: A is a Tensor.
u is a List, with all elements equal in length.

Inputs:
data (Tensor): Assigned tensor.
index (Number): An integer index.
value (List): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
if isinstance(index, bool):
return compile_utils.tensor_setitem_by_bool(data, index, value)
return compile_utils.tensor_setitem_by_number_with_sequence(data, index, value)


@setitem.register("Tensor", "Ellipsis", "Number")
def _tensor_setitem_with_ellipsis_v1(data, index, value):
return compile_utils.tensor_setitem_by_ellipsis_with_number(data, index, value)
def _tensor_setitem_by_ellipsis_with_number(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[...] = u
Restraint condition: A is a Tensor.
u is a Number.
Inputs:
data (Tensor): Assigned tensor.
index (Ellipsis): Index is ``...``.
value (Number): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_ellipsis_with_number(data, value)


@setitem.register("Tensor", "Ellipsis", "Tensor")
def _tensor_setitem_with_ellipsis_v2(data, index, value):
return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, index, value)
def _tensor_setitem_by_ellipsis_with_tensor(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[...] = u
Restraint condition: A is a Tensor.
u is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
index (Ellipsis): Index is ``...``.
value (Tensor): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_ellipsis_with_tensor(data, value)


@setitem.register("Tensor", "Ellipsis", "List")
def _tensor_setitem_by_ellipsis_with_list(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[...] = u
Restraint condition: A is a Tensor.
u is a List, with all elements equal in length.
Inputs:
data (Tensor): Assigned tensor.
index (Ellipsis): Index is ``...``.
value (List): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value)


@setitem.register("Tensor", "Ellipsis", "Tuple")
def _tensor_setitem_by_ellipsis_with_tuple(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[...] = u
Restraint condition: A is a Tensor.
u is a Tuple, with all elements equal in length.
Inputs:
data (Tensor): Assigned tensor.
index (Ellipsis): Index is ``...``.
value (Tuple): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
return compile_utils.tensor_setitem_by_ellipsis_with_sequence(data, value)


@setitem.register("Tensor", "List", "Number")
def _tensor_setitem_by_list_with_number(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[List] = u
Restraint condition: A is a Tensor.
u is a Number.
Inputs:
data (Tensor): Assigned tensor.
index (List): List index.
value (Number): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
# list indices will be converted to tuple or tensor based on its contents.
index = compile_utils.format_list_indices(index, data.shape[0])
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_number(data, index, value)
if compile_utils.tuple_indices_have_false(index):
return data
index = compile_utils.format_tuple_indices(index)
return compile_utils.tensor_setitem_by_tuple_with_number(data, index, value)


@setitem.register("Tensor", "List", "Tensor")
def _tensor_setitem_by_list_with_tensor(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[List] = u
Restraint condition: A is a Tensor.
u is a Tensor.
Inputs:
data (Tensor): Assigned tensor.
index (List): List index.
value (Tensor): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
# list indices will be converted to tuple or tensor based on its contents.
index = compile_utils.format_list_indices(index, data.shape[0])
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_tensor(data, index, value)
if compile_utils.tuple_indices_have_false(index):
return data
index = compile_utils.format_tuple_indices(index)
return compile_utils.tensor_setitem_by_tuple_with_tensor(data, index, value)


@setitem.register("Tensor", "List", "Tuple")
def _tensor_setitem_by_list_with_tuple(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[List] = u
Restraint condition: A is a Tensor.
u is a Tuple, with all elements equal in length.
Inputs:
data (Tensor): Assigned tensor.
index (List): List index.
value (Tuple): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
# list indices will be converted to tuple or tensor based on its contents.
index = compile_utils.format_list_indices(index, data.shape[0])
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)
if compile_utils.tuple_indices_have_false(index):
return data
index = compile_utils.format_tuple_indices(index)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value)


@setitem.register("Tensor", "List", "List")
def _tensor_setitem_by_list_with_list(data, index, value):
"""
Tensor assignment.

Note:
Syntax support: A[List] = u
Restraint condition: A is a Tensor.
u is a List, with all elements equal in length.
Inputs:
data (Tensor): Assigned tensor.
index (List): List index.
value (List): Assignment value.

Outputs:
Tensor, element type and shape is same as data.
"""
# list indices will be converted to tuple or tensor based on its contents.
index = compile_utils.format_list_indices(index, data.shape[0])
if isinstance(index, Tensor):
return compile_utils.tensor_setitem_by_tensor_with_sequence(data, index, value)
if compile_utils.tuple_indices_have_false(index):
return data
index = compile_utils.format_tuple_indices(index)
return compile_utils.tensor_setitem_by_tuple_with_sequence(data, index, value)

tests/st/pynative/test_tensor_index.py (+3 -3)

@@ -321,7 +321,7 @@ def test_setitem_by_mixed_tensors_2():
assert np.all(out.asnumpy() == (input_np + const))


class TensorGetItemByMixedTensorsTypeError(Cell):
class TensorGetItemByMixedTensorsIndexError(Cell):
def construct(self, x, index_0, index_1):
ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
return ret
@@ -331,8 +331,8 @@ def test_getitem_by_mixedtensor_exception():
input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32)
index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32)
index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)
net1 = TensorGetItemByMixedTensorsTypeError()
with pytest.raises(TypeError):
net1 = TensorGetItemByMixedTensorsIndexError()
with pytest.raises(IndexError):
net1(input_ms, index_0, index_1)




tests/st/pynative/test_tensor_setitem.py (+215 -0)

@@ -0,0 +1,215 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_tensor_setitem """
import numpy as onp
import pytest

from mindspore import Tensor, context
from mindspore.nn import Cell


def setup_module():
context.set_context(mode=context.GRAPH_MODE)


def setup_testcase(input_np, case_fn):
input_ms = Tensor(input_np)

class TensorSetItem(Cell):
def construct(self, x):
return case_fn(x)

class NumpySetItem():
def __call__(self, x):
return case_fn(x)

out_ms = TensorSetItem()(input_ms)
out_np = NumpySetItem()(input_np)
assert onp.all(out_ms.asnumpy() == out_np)


class TensorSetItemByList(Cell):
def construct(self, x):
x[[0, 1], [1, 2], [1, 3]] = [3, 4]
x[([0, 1], [0, 2], [1, 1])] = [10, 5]
x[[0, 1], ..., [0, 1]] = 4
return x

class NumpySetItemByList():
def __call__(self, x):
x[[0, 1], [1, 2], [1, 3]] = [3, 4]
x[([0, 1], [0, 2], [1, 1])] = [10, 5]
x[[0, 1], ..., [0, 1]] = 4
return x

@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_list():
x = onp.ones((2, 3, 4), dtype=onp.float32)
def cases(x):
x[[0, 1], [1, 2], [1, 3]] = [3, 4]
x[([0, 1], [0, 2], [1, 1])] = [10, 5]
x[[0, 1], ..., [0, 1]] = 4
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_sequence():
x = onp.ones((2, 3, 4), dtype=onp.float32)
def cases(x):
x[...] = [3]
x[..., 1] = ([1, 2, 3], [4, 5, 6])
x[0] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
x[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dtype():
x = onp.ones((2, 3, 4), dtype=onp.float32)
def cases(x):
x[...] = 3
x[..., 1] = 3.0
x[0] = True
x[1:2] = ((0, False, 2, 3), (4.0, 5, 6, 7), [True, 9, 10, 11])
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_int():
x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
def cases(x):
x[..., 2, False, 1] = -1
x[0, True, 0, None, True] = -2
x[0, ..., None] = -3
x[..., 0, None, 1, True, True, None] = -4
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_list():
x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
def cases(x):
x[..., 2, False, 1] = [-1]
x[0, True, 0, None, True] = [-2, -2, -2, -2]
x[0, ..., None] = [[-3], [-3], [-3], [-3]]
x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_nested_unit_list():
x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
def cases(x):
x[[[[0]]], True] = -1
x[[1], ..., [[[[2]]]]] = -2
x[0, [[[2]]], [1]] = -3
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_broadcast():
x = onp.arange(2*3*4*5*6).reshape(2, 3, 4, 5, 6).astype(onp.float32)
v1 = onp.full((1, 4, 5), -1).tolist()
v2 = onp.full((4, 1, 6), -2).tolist()
def cases(x):
x[..., 4] = v1
x[0, 2] = v2
x[1, 0, ..., 3] = [[-3], [-3], [-3], [-3]]
x[0, ..., 1, 3, 5] = -4
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_mul_by_scalar():
x = onp.ones((4, 5), dtype=onp.float32)
def cases(x):
x[1, :] = x[1, :]*2
x[:, 2] = x[:, 3]*3.0
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_slice():
x = onp.ones((3, 4, 5), dtype=onp.float32)
def cases(x):
x[1:2] = 2
x[-3:1] = 3
x[-10:3:2] = 4
x[5:0:3] = 5
x[5:5:5] = 6
x[-1:2] = 7
return x
setup_testcase(x, cases)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_of_slices():
x = onp.ones((3, 4, 5), dtype=onp.float32)
def cases(x):
x[1:2, 2] = 2
x[0, -4:1] = 3
x[1, -10:3:2] = 4
x[5:0:3, 3] = 5
x[1:1, 2:2] = 6
return x
setup_testcase(x, cases)
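
# The new cases can be run standalone, e.g.:
#   pytest tests/st/pynative/test_tensor_setitem.py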
