|
|
@@ -23,7 +23,7 @@ from ....common import dtype as mstype |
|
|
from ....common._register_for_tensor import tensor_operator_registry |
|
|
from ....common._register_for_tensor import tensor_operator_registry |
|
|
|
|
|
|
|
|
hyper_map = base.HyperMap() |
|
|
hyper_map = base.HyperMap() |
|
|
pack = P.Stack(axis=-1) |
|
|
|
|
|
|
|
|
stack = P.Stack(axis=-1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem(self, index): |
|
|
def _tensor_getitem(self, index): |
|
|
@@ -36,44 +36,35 @@ def _tensor_getitem(self, index): |
|
|
return tensor_index_by_tuple(self, (index,)) |
|
|
return tensor_index_by_tuple(self, (index,)) |
|
|
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, " |
|
|
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, " |
|
|
f"list and tuple ,but got {index} with type {type(index)}.") |
|
|
f"list and tuple ,but got {index} with type {type(index)}.") |
|
|
|
|
|
|
|
|
|
|
|
tensor_operator_registry.register("__getitem__", _tensor_getitem) |
|
|
|
|
|
|
|
|
def _tensor_setitem(self, index, value): |
|
|
def _tensor_setitem(self, index, value): |
|
|
"""Handle tensor getitem""" |
|
|
|
|
|
|
|
|
"""Handle tensor setitem""" |
|
|
|
|
|
if not isinstance(value, (int, float, bool, list, tuple, Tensor)): |
|
|
|
|
|
raise ValueError(f"only support numbers, Tensor, tuple, list as value," |
|
|
|
|
|
f"but got {value} with type {type(value)}.") |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(index, list): |
|
|
|
|
|
index = format_list_indices(index, self.shape[0]) |
|
|
if isinstance(index, Tensor): |
|
|
if isinstance(index, Tensor): |
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_tensor_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_tensor_with_tensor(self, index, value) |
|
|
|
|
|
if isinstance(value, tuple): |
|
|
|
|
|
return tensor_setitem_by_tensor_with_tuple(self, index, value) |
|
|
|
|
|
|
|
|
return tensor_setitem_by_tensor(self, index, value) |
|
|
if isinstance(index, tuple): |
|
|
if isinstance(index, tuple): |
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_tuple_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_tuple_with_tensor(self, index, value) |
|
|
|
|
|
if isinstance(value, tuple): |
|
|
|
|
|
return tensor_setitem_by_tuple_with_tuple(self, index, value) |
|
|
|
|
|
|
|
|
if tuple_indices_have_false(index): |
|
|
|
|
|
return self |
|
|
|
|
|
index = format_tuple_indices(index) |
|
|
|
|
|
return tensor_setitem_by_tuple(self, index, value) |
|
|
|
|
|
if isinstance(index, bool): |
|
|
|
|
|
return tensor_setitem_by_bool(self, index, value) |
|
|
if isinstance(index, int): |
|
|
if isinstance(index, int): |
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_number_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_number_with_tensor(self, index, value) |
|
|
|
|
|
|
|
|
return tensor_setitem_by_number(self, index, value) |
|
|
if isinstance(index, slice): |
|
|
if isinstance(index, slice): |
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_slice_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_slice_with_tensor(self, index, value) |
|
|
|
|
|
if isinstance(index, bool): |
|
|
|
|
|
return _tensor_index_by_bool(self, index) |
|
|
|
|
|
|
|
|
return tensor_setitem_by_slice(self, index, value) |
|
|
if index is ...: |
|
|
if index is ...: |
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_ellipsis_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_ellipsis_with_tensor(self, index, value) |
|
|
|
|
|
|
|
|
return tensor_setitem_by_ellipsis(self, index, value) |
|
|
|
|
|
|
|
|
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ |
|
|
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ |
|
|
and tensor with int32, got {} with type{}".format(index, type(index))) |
|
|
and tensor with int32, got {} with type{}".format(index, type(index))) |
|
|
|
|
|
|
|
|
|
|
|
tensor_operator_registry.register("__setitem__", _tensor_setitem) |
|
|
|
|
|
|
|
|
def _broadcast(broadcast_shape, x): |
|
|
def _broadcast(broadcast_shape, x): |
|
|
"""Broadcast tensor to the required shape.""" |
|
|
"""Broadcast tensor to the required shape.""" |
|
|
@@ -103,7 +94,8 @@ def _transform_ellipsis_to_slice(data, tuple_index, op_name): |
|
|
ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) + |
|
|
ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) + |
|
|
len(tensor_positions) + len(sequence_positions)) |
|
|
len(tensor_positions) + len(sequence_positions)) |
|
|
ellipsis_cnt = len(ellipsis_positions) |
|
|
ellipsis_cnt = len(ellipsis_positions) |
|
|
if (ellipsis_cnt == 0 and ellipsis_occupy_dims < 0) or (ellipsis_cnt > 0 and ellipsis_occupy_dims < 1): |
|
|
|
|
|
|
|
|
# pylint: disable=chained-comparison |
|
|
|
|
|
if ellipsis_occupy_dims < 0 and ellipsis_cnt >= 0: |
|
|
const_utils.raise_index_error("For the 'getitem Operator', the data_shape should be no less than the " |
|
|
const_utils.raise_index_error("For the 'getitem Operator', the data_shape should be no less than the " |
|
|
"tuple index dims") |
|
|
"tuple index dims") |
|
|
|
|
|
|
|
|
@@ -155,14 +147,6 @@ def tensor_index_by_number(data, number_index): |
|
|
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") |
|
|
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO wait to remove after setitem by Yang Linfeng |
|
|
|
|
|
def _tensor_index_by_bool(data, bool_index): |
|
|
|
|
|
"""Tensor getitem by a single bool value""" |
|
|
|
|
|
if bool_index: |
|
|
|
|
|
return F.expand_dims(data, 0) |
|
|
|
|
|
return const_utils.make_tensor([], data.dtype, (0,) + F.shape(data)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_index_by_integer(data, int_index): |
|
|
def _tensor_index_by_integer(data, int_index): |
|
|
"""Tensor getitem by a single integer number""" |
|
|
"""Tensor getitem by a single integer number""" |
|
|
data_shape = F.shape(data) |
|
|
data_shape = F.shape(data) |
|
|
@@ -218,6 +202,31 @@ def tensor_index_by_tuple(data, tuple_index): |
|
|
return _tensor_getitem_by_tuple(data, tuple_index, op_name) |
|
|
return _tensor_getitem_by_tuple(data, tuple_index, op_name) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name): |
|
|
|
|
|
"""Tensor getitem by a tuple of tensor.""" |
|
|
|
|
|
data_shape = F.shape(data) |
|
|
|
|
|
tuple_index_len = len(tuple_index) |
|
|
|
|
|
|
|
|
|
|
|
indexes_types = hyper_map(F.dtype, tuple_index) |
|
|
|
|
|
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name) |
|
|
|
|
|
tensor_index_shape = hyper_map(F.shape, tuple_index) |
|
|
|
|
|
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) |
|
|
|
|
|
if 0 in broadcast_shape: |
|
|
|
|
|
res_shape = broadcast_shape |
|
|
|
|
|
if tuple_index_len < len(data_shape): |
|
|
|
|
|
res_shape += data_shape[tuple_index_len:] |
|
|
|
|
|
res = const_utils.make_tensor([], data.dtype, res_shape) |
|
|
|
|
|
return res |
|
|
|
|
|
|
|
|
|
|
|
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) |
|
|
|
|
|
new_broadcast_tensors = () |
|
|
|
|
|
for tensor in broadcast_tensors: |
|
|
|
|
|
new_broadcast_tensors += (F.cast(tensor, mstype.int64),) |
|
|
|
|
|
indices = stack(new_broadcast_tensors) |
|
|
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple_slice(data, tuple_index): |
|
|
def _tensor_getitem_by_tuple_slice(data, tuple_index): |
|
|
"""Tensor getitem by a tuple of slice""" |
|
|
"""Tensor getitem by a tuple of slice""" |
|
|
data_shape = F.shape(data) |
|
|
data_shape = F.shape(data) |
|
|
@@ -284,8 +293,9 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name): |
|
|
final_index_tensors.append(slice_index_tensor) |
|
|
final_index_tensors.append(slice_index_tensor) |
|
|
slice_cnt += 1 |
|
|
slice_cnt += 1 |
|
|
|
|
|
|
|
|
indices = pack(final_index_tensors) |
|
|
|
|
|
return F.gather_nd(data, indices) |
|
|
|
|
|
|
|
|
indices = stack(final_index_tensors) |
|
|
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): |
|
|
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): |
|
|
@@ -299,7 +309,7 @@ def _generate_indices_from_tuple_of_tensor(tuple_index, op_name): |
|
|
new_broadcast_tensors = () |
|
|
new_broadcast_tensors = () |
|
|
for tensor in broadcast_tensors: |
|
|
for tensor in broadcast_tensors: |
|
|
new_broadcast_tensors += (F.cast(tensor, mstype.int64),) |
|
|
new_broadcast_tensors += (F.cast(tensor, mstype.int64),) |
|
|
indices = pack(new_broadcast_tensors) |
|
|
|
|
|
|
|
|
indices = stack(new_broadcast_tensors) |
|
|
return indices |
|
|
return indices |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -332,6 +342,11 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): |
|
|
tuple_index_new += (tensor_index,) |
|
|
tuple_index_new += (tensor_index,) |
|
|
tensor_indexes.append(tensor_index) |
|
|
tensor_indexes.append(tensor_index) |
|
|
elif i in slice_positions: |
|
|
elif i in slice_positions: |
|
|
|
|
|
start, stop, _ = const_utils.slice_to_tuple(index) |
|
|
|
|
|
start = const_utils.normalize_start(start, dim_size) |
|
|
|
|
|
stop = const_utils.normalize_stop(stop, dim_size) |
|
|
|
|
|
if start >= stop: |
|
|
|
|
|
return None |
|
|
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size) |
|
|
slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size) |
|
|
slice_shapes += (len(slice_ele_list_index),) |
|
|
slice_shapes += (len(slice_ele_list_index),) |
|
|
tuple_index_new += (slice_ele_list_index,) |
|
|
tuple_index_new += (slice_ele_list_index,) |
|
|
@@ -354,7 +369,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): |
|
|
final_index_tensors.append(slice_index_tensor) |
|
|
final_index_tensors.append(slice_index_tensor) |
|
|
slice_cnt += 1 |
|
|
slice_cnt += 1 |
|
|
|
|
|
|
|
|
indices = pack(final_index_tensors) |
|
|
|
|
|
|
|
|
indices = stack(final_index_tensors) |
|
|
return indices |
|
|
return indices |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -366,44 +381,76 @@ def _generate_updates_from_scalar(data, indices, value, op_type): |
|
|
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) |
|
|
return const_utils.convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_updates_from_tuple(data, index, value, op_type): |
|
|
|
|
|
"""Generate an updates tensor from a tuple.""" |
|
|
|
|
|
|
|
|
def _generate_updates_from_sequence(data, index, value, op_type): |
|
|
|
|
|
"""Generate an updates tensor from a tuple, can only handle 1-D tensor/non-tensor mixtures.""" |
|
|
value_types = hyper_map(F.typeof, value) |
|
|
value_types = hyper_map(F.typeof, value) |
|
|
data_dtype = F.dtype(data) |
|
|
|
|
|
value_elements_type = const_utils.check_value_elements(data_dtype, value_types) |
|
|
|
|
|
|
|
|
value_elements_type = const_utils.check_value_elements(value_types) |
|
|
|
|
|
|
|
|
if value_elements_type == const_utils.ALL_TENSOR: |
|
|
if value_elements_type == const_utils.ALL_TENSOR: |
|
|
value_shapes = hyper_map(F.shape, value) |
|
|
|
|
|
shapes_same = const_utils.check_shapes_same(value_shapes, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
if shapes_same: |
|
|
|
|
|
value = F.stack(value) |
|
|
|
|
|
return _generate_updates_from_tensor(data, index, value, op_type) |
|
|
|
|
|
|
|
|
|
|
|
data_shape = F.shape(data) |
|
|
|
|
|
index_shape = F.shape(index) |
|
|
|
|
|
return const_utils.convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type) |
|
|
|
|
|
|
|
|
value = F.stack(value).astype(data.dtype) |
|
|
|
|
|
elif value_elements_type == const_utils.NO_TENSOR: |
|
|
|
|
|
value = const_utils.make_tensor(value, data.dtype) |
|
|
|
|
|
else: |
|
|
|
|
|
new_value = () |
|
|
|
|
|
for ele in value: |
|
|
|
|
|
ele = ele if isinstance(ele, Tensor) else const_utils.make_tensor(ele) |
|
|
|
|
|
new_value += (ele,) |
|
|
|
|
|
value = F.stack(new_value).astype(data.dtype) |
|
|
|
|
|
if op_type == const_utils.SET_ITEM_BY_NON_TENSOR: |
|
|
|
|
|
return value |
|
|
|
|
|
return _generate_updates_from_tensor(data, index, value, op_type) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_updates_from_tensor(data, index, value, op_type): |
|
|
def _generate_updates_from_tensor(data, index, value, op_type): |
|
|
"""Generate an updates tensor from a tensor.""" |
|
|
"""Generate an updates tensor from a tensor.""" |
|
|
data_shape = F.shape(data) |
|
|
|
|
|
index_shape = F.shape(index) |
|
|
|
|
|
value_shape = F.shape(value) |
|
|
|
|
|
data_dtype = F.dtype(data) |
|
|
|
|
|
value_dtype = F.dtype(value) |
|
|
|
|
|
updates_shape = value_shape |
|
|
|
|
|
check_dtype_same = const_utils.check_tensors_dtype_same(data_dtype, value_dtype, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
if check_dtype_same: |
|
|
|
|
|
updates_shape = const_utils.generate_updates_shape(data_shape, index_shape, op_type) |
|
|
|
|
|
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value_shape) |
|
|
|
|
|
|
|
|
value = value.astype(data.dtype) |
|
|
|
|
|
updates_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type) |
|
|
|
|
|
need_broadcast = const_utils.check_two_shapes_need_broadcast(updates_shape, value.shape) |
|
|
if need_broadcast: |
|
|
if need_broadcast: |
|
|
return _broadcast(updates_shape, value) |
|
|
return _broadcast(updates_shape, value) |
|
|
return value |
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tensor_operator_registry.register("__getitem__", _tensor_getitem) |
|
|
|
|
|
|
|
|
# Tensor getitem implementations are above this line, setitem implementations below. |
|
|
|
|
|
|
|
|
tensor_operator_registry.register("__setitem__", _tensor_setitem) |
|
|
|
|
|
|
|
|
def tensor_setitem_by_tensor(self, index, value): |
|
|
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_tensor_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_tensor_with_tensor(self, index, value) |
|
|
|
|
|
return tensor_setitem_by_tensor_with_sequence(self, index, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple(self, index, value): |
|
|
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_tuple_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_tuple_with_tensor(self, index, value) |
|
|
|
|
|
return tensor_setitem_by_tuple_with_sequence(self, index, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_number(self, index, value): |
|
|
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_number_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_number_with_tensor(self, index, value) |
|
|
|
|
|
return tensor_setitem_by_number_with_sequence(self, index, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_slice(self, index, value): |
|
|
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_slice_with_number(self, index, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_slice_with_tensor(self, index, value) |
|
|
|
|
|
return tensor_setitem_by_slice_with_sequence(self, index, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis(self, index, value): |
|
|
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
|
|
return tensor_setitem_by_ellipsis_with_number(self, value) |
|
|
|
|
|
if isinstance(value, Tensor): |
|
|
|
|
|
return tensor_setitem_by_ellipsis_with_tensor(self, value) |
|
|
|
|
|
return tensor_setitem_by_ellipsis_with_sequence(self, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): |
|
|
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): |
|
|
@@ -469,17 +516,16 @@ def tensor_setitem_by_tensor_with_number(data, index, value): |
|
|
return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int") |
|
|
return const_utils.raise_index_error("For tensor setitem, indexing tensor dtype only supports bool/int") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tensor_with_tuple(data, index, value): |
|
|
|
|
|
|
|
|
def tensor_setitem_by_tensor_with_sequence(data, index, value): |
|
|
"""Assigns the tensor by tensor with tuple value.""" |
|
|
"""Assigns the tensor by tensor with tuple value.""" |
|
|
index_dtype = F.dtype(index) |
|
|
index_dtype = F.dtype(index) |
|
|
const_utils.check_type_valid(index_dtype, (mstype.int32, mstype.int64), const_utils.TENSOR_SETITEM) |
|
|
const_utils.check_type_valid(index_dtype, (mstype.int32, mstype.int64), const_utils.TENSOR_SETITEM) |
|
|
result = _tensor_setitem_by_tensor_with_tuple(data, index, value) |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
return _tensor_setitem_by_tensor_with_sequence(data, index, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_indices_number(data, data_shape, index, indices, value): |
|
|
def _tensor_indices_number(data, data_shape, index, indices, value): |
|
|
"""Assigns a scalar value to the tensor.""" |
|
|
"""Assigns a scalar value to the tensor.""" |
|
|
data_size = F.size(data) |
|
|
|
|
|
|
|
|
data_size = F.shape_mul(data.shape) |
|
|
data_dtype = F.dtype(data) |
|
|
data_dtype = F.dtype(data) |
|
|
indices_size = F.size(indices) |
|
|
indices_size = F.size(indices) |
|
|
indices_size = const_utils.check_indices(indices_size, index) |
|
|
indices_size = const_utils.check_indices(indices_size, index) |
|
|
@@ -493,9 +539,9 @@ def _tensor_indices_number(data, data_shape, index, indices, value): |
|
|
return F.select(condition, u, data) |
|
|
return F.select(condition, u, data) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_setitem_by_tensor_with_tuple(data, index, value): |
|
|
|
|
|
|
|
|
def _tensor_setitem_by_tensor_with_sequence(data, index, value): |
|
|
"""Set a tensor item by a tensor with a tuple.""" |
|
|
"""Set a tensor item by a tensor with a tuple.""" |
|
|
updates = _generate_updates_from_tuple(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) |
|
|
|
|
|
|
|
|
updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) |
|
|
index = F.expand_dims(index, -1) |
|
|
index = F.expand_dims(index, -1) |
|
|
return P.TensorScatterUpdate()(data, index, updates) |
|
|
return P.TensorScatterUpdate()(data, index, updates) |
|
|
|
|
|
|
|
|
@@ -507,6 +553,8 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value): |
|
|
if check_result: |
|
|
if check_result: |
|
|
data_shape = F.shape(data) |
|
|
data_shape = F.shape(data) |
|
|
indices = const_utils.slice2indices(input_slice, data_shape) |
|
|
indices = const_utils.slice2indices(input_slice, data_shape) |
|
|
|
|
|
if indices is False: |
|
|
|
|
|
return data |
|
|
is_tuple_int = const_utils.tuple_element_is_int(input_slice) |
|
|
is_tuple_int = const_utils.tuple_element_is_int(input_slice) |
|
|
if is_tuple_int: |
|
|
if is_tuple_int: |
|
|
indices = const_utils.integer_to_indices(input_slice, data_shape) |
|
|
indices = const_utils.integer_to_indices(input_slice, data_shape) |
|
|
@@ -516,6 +564,8 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value): |
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value): |
|
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value): |
|
|
"""Assigns the tensor by tuple with number value.""" |
|
|
"""Assigns the tensor by tuple with number value.""" |
|
|
|
|
|
tuple_index = ignore_dim_expand(tuple_index) |
|
|
|
|
|
|
|
|
if len(tuple_index) == 1: |
|
|
if len(tuple_index) == 1: |
|
|
data[tuple_index[0]] = value |
|
|
data[tuple_index[0]] = value |
|
|
return data |
|
|
return data |
|
|
@@ -533,13 +583,15 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): |
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
tuple_index = const_utils.convert_int_to_slice(tuple_index) |
|
|
tuple_index = const_utils.convert_int_to_slice(tuple_index) |
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
if indices is None: |
|
|
|
|
|
return data |
|
|
updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_indices_tensor(data, data_shape, index, indices, value): |
|
|
def _tensor_indices_tensor(data, data_shape, index, indices, value): |
|
|
"""Assigns a tensor value to the tensor.""" |
|
|
"""Assigns a tensor value to the tensor.""" |
|
|
data_size = F.size(data) |
|
|
|
|
|
|
|
|
data_size = F.shape_mul(data.shape) |
|
|
data_dtype = F.dtype(data) |
|
|
data_dtype = F.dtype(data) |
|
|
indices_size = F.size(indices) |
|
|
indices_size = F.size(indices) |
|
|
indices_size = const_utils.check_indices(indices_size, index) |
|
|
indices_size = const_utils.check_indices(indices_size, index) |
|
|
@@ -548,7 +600,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): |
|
|
condition = F.reshape(condition_1d, data_shape) |
|
|
condition = F.reshape(condition_1d, data_shape) |
|
|
condition = F.cast(condition, mstype.bool_) |
|
|
condition = F.cast(condition, mstype.bool_) |
|
|
value_fill = None |
|
|
value_fill = None |
|
|
value_size = F.size(value) |
|
|
|
|
|
|
|
|
value_size = value.size |
|
|
|
|
|
|
|
|
value_size = const_utils.check_indices_value_size(indices_size, value_size) |
|
|
value_size = const_utils.check_indices_value_size(indices_size, value_size) |
|
|
if value_size == 1: |
|
|
if value_size == 1: |
|
|
@@ -559,7 +611,7 @@ def _tensor_indices_tensor(data, data_shape, index, indices, value): |
|
|
value_fill = F.reshape(value, (indices_size,)) |
|
|
value_fill = F.reshape(value, (indices_size,)) |
|
|
value_1d = F.scatter_nd(indices, value_fill, (data_size,)) |
|
|
value_1d = F.scatter_nd(indices, value_fill, (data_size,)) |
|
|
u = F.reshape(value_1d, data_shape) |
|
|
u = F.reshape(value_1d, data_shape) |
|
|
return F.select(condition, u, data) |
|
|
|
|
|
|
|
|
return F.select(condition, u.astype(data_dtype), data) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_slice_with_tensor(data, input_slice, value): |
|
|
def tensor_setitem_by_slice_with_tensor(data, input_slice, value): |
|
|
@@ -569,6 +621,8 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): |
|
|
if check_result: |
|
|
if check_result: |
|
|
data_shape = F.shape(data) |
|
|
data_shape = F.shape(data) |
|
|
indices = const_utils.slice2indices(input_slice, data_shape) |
|
|
indices = const_utils.slice2indices(input_slice, data_shape) |
|
|
|
|
|
if indices is False: |
|
|
|
|
|
return data |
|
|
is_tuple_int = const_utils.tuple_element_is_int(input_slice) |
|
|
is_tuple_int = const_utils.tuple_element_is_int(input_slice) |
|
|
if is_tuple_int: |
|
|
if is_tuple_int: |
|
|
indices = const_utils.integer_to_indices(input_slice, data_shape) |
|
|
indices = const_utils.integer_to_indices(input_slice, data_shape) |
|
|
@@ -576,8 +630,18 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): |
|
|
return result |
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_slice_with_sequence(data, input_slice, value): |
|
|
|
|
|
"""Assigns a list/tuple value to the tensor by slice.""" |
|
|
|
|
|
value = _generate_updates_from_sequence(data, input_slice, value, const_utils.SET_ITEM_BY_NON_TENSOR) |
|
|
|
|
|
return tensor_setitem_by_slice_with_tensor(data, input_slice, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): |
|
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): |
|
|
"""Assigns the tensor by tuple with tensor value.""" |
|
|
"""Assigns the tensor by tuple with tensor value.""" |
|
|
|
|
|
value_shape = remove_ignored_dim(tuple_index, F.shape(value), F.rank(data)) |
|
|
|
|
|
value = F.reshape(value, value_shape) |
|
|
|
|
|
tuple_index = ignore_dim_expand(tuple_index) |
|
|
|
|
|
|
|
|
if len(tuple_index) == 1: |
|
|
if len(tuple_index) == 1: |
|
|
data[tuple_index[0]] = value |
|
|
data[tuple_index[0]] = value |
|
|
return data |
|
|
return data |
|
|
@@ -600,31 +664,15 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): |
|
|
new_shape += value.shape |
|
|
new_shape += value.shape |
|
|
value = F.reshape(value, new_shape) |
|
|
value = F.reshape(value, new_shape) |
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
if indices is None: |
|
|
|
|
|
return data |
|
|
updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): |
|
|
|
|
|
"""Assigns the tensor by tuple with tuple of value.""" |
|
|
|
|
|
if len(tuple_index) == 1: |
|
|
|
|
|
data[tuple_index[0]] = value |
|
|
|
|
|
return data |
|
|
|
|
|
op_name = const_utils.TENSOR_GETITEM |
|
|
|
|
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) |
|
|
|
|
|
data, tuple_index = _expand_data_dims(data, tuple_index) |
|
|
|
|
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
|
|
|
|
|
|
if contain_type == const_utils.ALL_TENSOR: |
|
|
|
|
|
indices = _generate_indices_from_tuple_of_tensor(tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
else: |
|
|
|
|
|
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
|
|
|
tuple_index = const_utils.convert_int_to_slice(tuple_index) |
|
|
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
updates = _generate_updates_from_tuple(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
|
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value): |
|
|
|
|
|
value = _generate_updates_from_sequence(data, tuple_index, value, const_utils.SET_ITEM_BY_NON_TENSOR) |
|
|
|
|
|
return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_number_with_number(data, index, value): |
|
|
def tensor_setitem_by_number_with_number(data, index, value): |
|
|
@@ -634,6 +682,12 @@ def tensor_setitem_by_number_with_number(data, index, value): |
|
|
return _tensor_indices_number(data, data_shape, index, indices, value) |
|
|
return _tensor_indices_number(data, data_shape, index, indices, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_number_with_sequence(data, index, value): |
|
|
|
|
|
"""Assigns a list/tuple value to the tensor by slice.""" |
|
|
|
|
|
value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR) |
|
|
|
|
|
return tensor_setitem_by_number_with_tensor(data, index, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_number_with_tensor(data, index, value): |
|
|
def tensor_setitem_by_number_with_tensor(data, index, value): |
|
|
"""Assigns the tensor by number with tensor value.""" |
|
|
"""Assigns the tensor by number with tensor value.""" |
|
|
data_shape = F.shape(data) |
|
|
data_shape = F.shape(data) |
|
|
@@ -641,31 +695,46 @@ def tensor_setitem_by_number_with_tensor(data, index, value): |
|
|
return _tensor_indices_tensor(data, data_shape, index, indices, value) |
|
|
return _tensor_indices_tensor(data, data_shape, index, indices, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis_with_number(data, index, value): |
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis_with_number(data, value): |
|
|
"""Assigns the tensor by ellipsis with number value.""" |
|
|
"""Assigns the tensor by ellipsis with number value.""" |
|
|
data_shape = F.shape(data) |
|
|
data_shape = F.shape(data) |
|
|
data_dtype = F.dtype(data) |
|
|
data_dtype = F.dtype(data) |
|
|
return F.fill(data_dtype, data_shape, value) |
|
|
return F.fill(data_dtype, data_shape, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis_with_tensor(data, index, value): |
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis_with_tensor(data, value): |
|
|
"""Assigns the tensor by ellipsis with tensor value.""" |
|
|
"""Assigns the tensor by ellipsis with tensor value.""" |
|
|
result = None |
|
|
|
|
|
data_shape = F.shape(data) |
|
|
data_shape = F.shape(data) |
|
|
data_dtype = F.dtype(data) |
|
|
data_dtype = F.dtype(data) |
|
|
data_size = F.size(data) |
|
|
|
|
|
|
|
|
value = value.astype(data_dtype) |
|
|
value_shape = F.shape(value) |
|
|
value_shape = F.shape(value) |
|
|
value_size = F.size(value) |
|
|
|
|
|
check_result = const_utils.check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size) |
|
|
|
|
|
if check_result: |
|
|
|
|
|
if data_size == value_size: |
|
|
|
|
|
result = F.reshape(value, data_shape) |
|
|
|
|
|
result = F.cast(result, data_dtype) |
|
|
|
|
|
elif value_size == 1: |
|
|
|
|
|
param1 = F.fill(data_dtype, data_shape, 1) |
|
|
|
|
|
param2 = F.cast(value, data_dtype) |
|
|
|
|
|
result = F.tensor_mul(param1, param2) |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
source_shape = const_utils.get_source_shape(data_shape, value_shape) |
|
|
|
|
|
value = F.reshape(value, source_shape) |
|
|
|
|
|
value = _broadcast(data_shape, value) |
|
|
|
|
|
data = F.cast(value, data_dtype) |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis_with_sequence(data, value): |
|
|
|
|
|
"""Assigns a list/tuple value to the tensor by ellipsis.""" |
|
|
|
|
|
value = _generate_updates_from_sequence(data, None, value, const_utils.SET_ITEM_BY_NON_TENSOR) |
|
|
|
|
|
return tensor_setitem_by_ellipsis_with_tensor(data, value) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_bool(data, index, value): |
|
|
|
|
|
"""Assigns a value to the tensor by boolean.""" |
|
|
|
|
|
data_shape = F.shape(data) |
|
|
|
|
|
if not index: |
|
|
|
|
|
data_shape = (0,) + data_shape |
|
|
|
|
|
if not isinstance(value, Tensor): |
|
|
|
|
|
value = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_NON_TENSOR) |
|
|
|
|
|
value_shape = F.shape(value) |
|
|
|
|
|
source_shape = const_utils.get_source_shape(data_shape, value_shape) |
|
|
|
|
|
if index: |
|
|
|
|
|
value = F.reshape(value, source_shape) |
|
|
|
|
|
value = _broadcast(data_shape, value) |
|
|
|
|
|
data = value |
|
|
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_in_sequence(x, y): |
|
|
def tensor_in_sequence(x, y): |
|
|
@@ -675,3 +744,79 @@ def tensor_in_sequence(x, y): |
|
|
if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype: |
|
|
if isinstance(i, Tensor) and x.shape == i.shape and x.dtype == i.dtype: |
|
|
result = F.logical_or(F.equal(x, i).all(), result) |
|
|
result = F.logical_or(F.equal(x, i).all(), result) |
|
|
return result |
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_list_indices(list_indices, length): |
|
|
|
|
|
"""Convert list indices to tensor or tuple indices based on its contents.""" |
|
|
|
|
|
indices_types = hyper_map(F.typeof, list_indices) |
|
|
|
|
|
# If eyery element in list is bool, it's treated as 1-D bool tensor. |
|
|
|
|
|
# If every element in list is int(not all bool), it's treated as int tensor. |
|
|
|
|
|
if const_utils.judge_indexes_types(indices_types, mstype.int_type+(mstype.bool_,)): |
|
|
|
|
|
list_indices = const_utils.transform_sequence_index(list_indices, length, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
return const_utils.make_tensor(list_indices) |
|
|
|
|
|
# If list contains other types(.../list/tuple/None), it's treated as a tuple |
|
|
|
|
|
return const_utils.deep_tuple(list_indices) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_tuple_indices(tuple_indices): |
|
|
|
|
|
""" |
|
|
|
|
|
Format tuple indices by unpacking high-dimension tuple and removing expand |
|
|
|
|
|
dimension signs(Bool and None). |
|
|
|
|
|
""" |
|
|
|
|
|
res = () |
|
|
|
|
|
for i in tuple_indices: |
|
|
|
|
|
if isinstance(i, (list, tuple)): |
|
|
|
|
|
res += (const_utils.unpack(i),) |
|
|
|
|
|
else: |
|
|
|
|
|
res += (i,) |
|
|
|
|
|
return res |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tuple_indices_have_false(tuple_indices): |
|
|
|
|
|
"""Returns True if tuple_indices contains False.""" |
|
|
|
|
|
for i in tuple_indices: |
|
|
|
|
|
if i is False: |
|
|
|
|
|
return True |
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ignore_dim_expand(idx): |
|
|
|
|
|
"""Filters flags for dimension expansion from idx.""" |
|
|
|
|
|
res = () |
|
|
|
|
|
for i in idx: |
|
|
|
|
|
if not i is True and not i is None: |
|
|
|
|
|
res += (i,) |
|
|
|
|
|
if not res: |
|
|
|
|
|
res = (True,) |
|
|
|
|
|
return res |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_ignored_dim(idx, value_shape, data_rank): |
|
|
|
|
|
"""Removes dimensions in value that correspond to dimension expansion flags in index.""" |
|
|
|
|
|
has_ellipsis = False |
|
|
|
|
|
has_true = False |
|
|
|
|
|
cnt_trailing_expanded = 0 |
|
|
|
|
|
cnt_not_dim_expand = 0 |
|
|
|
|
|
for i in idx: |
|
|
|
|
|
if not i is True and not i is None: |
|
|
|
|
|
cnt_not_dim_expand += 1 |
|
|
|
|
|
if const_utils.is_ellipsis(i): |
|
|
|
|
|
has_ellipsis = True |
|
|
|
|
|
elif has_ellipsis: |
|
|
|
|
|
if i is None: |
|
|
|
|
|
cnt_trailing_expanded += 1 |
|
|
|
|
|
elif i is True and not has_true: |
|
|
|
|
|
has_true = True |
|
|
|
|
|
if has_true and cnt_not_dim_expand + 1 < data_rank: |
|
|
|
|
|
cnt_trailing_expanded += 1 |
|
|
|
|
|
|
|
|
|
|
|
if cnt_trailing_expanded == 0: |
|
|
|
|
|
return value_shape |
|
|
|
|
|
value_expanded_pos = len(value_shape) - cnt_trailing_expanded |
|
|
|
|
|
value_expanded_not_unit = False |
|
|
|
|
|
for i in value_shape[value_expanded_pos:]: |
|
|
|
|
|
if i != 1: |
|
|
|
|
|
value_expanded_not_unit = True |
|
|
|
|
|
if value_expanded_pos < 0 or value_expanded_not_unit: |
|
|
|
|
|
const_utils.raise_value_error('shape mismatch') |
|
|
|
|
|
return value_shape[:value_expanded_pos] |