|
|
|
@@ -26,6 +26,66 @@ hyper_map = base.HyperMap() |
|
|
|
pack = P.Pack(axis=-1) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem(self, index): |
|
|
|
"""Handle tensor getitem""" |
|
|
|
if isinstance(index, Tensor): |
|
|
|
return tensor_index_by_tensor(self, index) |
|
|
|
if isinstance(index, list): |
|
|
|
return tensor_index_by_list(self, index) |
|
|
|
if isinstance(index, tuple): |
|
|
|
return tensor_index_by_tuple(self, index) |
|
|
|
# bool must be checked before int, because bool is a subclass of int
|
|
|
if isinstance(index, bool): |
|
|
|
return _tensor_index_by_bool(self, index) |
|
|
|
if isinstance(index, int): |
|
|
|
return _tensor_index_by_integer(self, index) |
|
|
|
if isinstance(index, slice): |
|
|
|
return tensor_index_by_slice(self, index) |
|
|
|
if index is None: |
|
|
|
return F.expand_dims(self, 0) |
|
|
|
if index is ...: |
|
|
|
return self |
|
|
|
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, " |
|
|
|
f"list and tuple ,but got {index} with type {type(index)}.") |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_setitem(self, index, value): |
|
|
|
"""Handle tensor getitem""" |
|
|
|
if isinstance(index, Tensor): |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_tensor_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_tensor_with_tensor(self, index, value) |
|
|
|
if isinstance(value, tuple): |
|
|
|
return tensor_setitem_by_tensor_with_tuple(self, index, value) |
|
|
|
if isinstance(index, tuple): |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_tuple_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_tuple_with_tensor(self, index, value) |
|
|
|
if isinstance(value, tuple): |
|
|
|
return tensor_setitem_by_tuple_with_tuple(self, index, value) |
|
|
|
if isinstance(index, int): |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_number_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_number_with_tensor(self, index, value) |
|
|
|
if isinstance(index, slice): |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_slice_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_slice_with_tensor(self, index, value) |
|
|
|
if isinstance(index, bool): |
|
|
|
return _tensor_index_by_bool(self, index) |
|
|
|
if index is ...: |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_ellipsis_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_ellipsis_with_tensor(self, index, value) |
|
|
|
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ |
|
|
|
and tensor with int32, got {} with type{}".format(index, type(index))) |
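
# Illustrative dispatch sketch (assumed example, not part of the operator): for a Tensor `t`
# of shape (3, 4), `t[Tensor([0, 2])] = 1.0` routes to tensor_setitem_by_tensor_with_number,
# `t[0:2] = Tensor(...)` routes to tensor_setitem_by_slice_with_tensor, and
# `t[1, 2] = 5` routes to tensor_setitem_by_tuple_with_number.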
|
|
|
|
|
|
|
|
|
|
|
def _broadcast(broadcast_shape, x): |
|
|
|
"""Broadcast tensor to the required shape.""" |
|
|
|
if F.shape(x) == broadcast_shape: |
|
|
|
@@ -42,15 +102,21 @@ def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): |
|
|
|
return _broadcast(final_shape, F.reshape(x, new_shape)) |
|
|
|
|
|
|
|
|
|
|
|
def _transform_ellipsis_to_slice(tuple_index, data, op_name): |
|
|
|
"""transform ellipsis in the slice to several slice""" |
|
|
|
def _transform_ellipsis_to_slice(data, tuple_index, op_name): |
|
|
|
"""Check if the tuple index len is longer than the data's dims and transform ellipsis in the indices |
|
|
|
to several slice""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
data_rank = len(data_shape) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
slice_positions, ellipsis_positions, _, int_positions, _, tensor_positions, sequence_positions = \ |
|
|
|
const_utils.get_pos_of_indexes_types(indexes_types, op_name) |
|
|
|
|
|
|
|
ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) + |
|
|
|
len(tensor_positions) + len(sequence_positions)) |
|
|
|
ellipsis_cnt = len(ellipsis_positions) |
|
|
|
if (ellipsis_cnt == 0 and ellipsis_occupy_dims < 0) or (ellipsis_cnt > 0 and ellipsis_occupy_dims < 1): |
|
|
|
const_utils.raise_index_error("For the 'getitem Operator', the data_shape should be no less than the " |
|
|
|
"tuple index dims") |
|
|
|
|
|
|
|
tuple_index_new = () |
|
|
|
for i, index in enumerate(tuple_index): |
|
|
|
@@ -63,122 +129,172 @@ def _transform_ellipsis_to_slice(tuple_index, data, op_name): |
|
|
|
return tuple_index_new |
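
# Illustrative sketch of the transformation (assumed shapes for documentation): for data of
# shape (2, 3, 4, 5) and tuple_index (0, ..., 1), the ellipsis occupies 4 - 2 = 2 dimensions,
# so the index is rewritten as (0, slice(None), slice(None), 1).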
|
|
|
|
|
|
|
|
|
|
|
def _expand_data_dims_with_none(data, tuple_index, op_name): |
|
|
|
"""expand the data's dim with 'None' in tuple_index""" |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
none_positions, tuple_index_without_none = (), () |
|
|
|
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)): |
|
|
|
none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none) |
|
|
|
tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else (index,)
|
|
|
none_positions += (i,) if none_type_tag else () |
|
|
|
|
|
|
|
for dim in none_positions: |
|
|
|
data = F.expand_dims(data, dim) |
|
|
|
|
|
|
|
return data, tuple_index_without_none |
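
# Illustrative sketch (assumed shapes for documentation): for data of shape (3, 4) and
# tuple_index (None, 0), the data is expanded to shape (1, 3, 4) and the returned index
# becomes (slice(None, None, None), 0), i.e. each None is replaced by an empty slice.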
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_slice(data, slice_index): |
|
|
|
"""Tensor getitem by a single slice""" |
|
|
|
shape = F.shape(data) |
|
|
|
if not shape: |
|
|
|
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor" |
|
|
|
"cannot be 0.") |
|
|
|
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index) |
|
|
|
return F.strided_slice(data, begin_strides, end_strides, step_strides) |
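
# Illustrative sketch (assumed shapes): for data of shape (5, 4), data[1:4:2] selects rows 1
# and 3 via strided_slice, giving shape (2, 4); NumPy-style slice semantics are assumed here.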
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_number(data, number): |
|
|
|
"""Tensor getitem by a Number which may be integer/float/bool value""" |
|
|
|
number_type = const_utils.check_number_index_type(number) |
|
|
|
if number_type == const_utils.BOOL_: |
|
|
|
return _tensor_index_by_bool(data, number) |
|
|
|
if number_type == const_utils.INT_: |
|
|
|
return _tensor_index_by_integer(data, number) |
|
|
|
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_index_by_bool(data, bool_value): |
|
|
|
"""Tensor getitem by a single bool value""" |
|
|
|
if bool_value: |
|
|
|
return F.expand_dims(data, 0) |
|
|
|
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.") |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_index_by_integer(data, number): |
|
|
|
"""Tensor getitem by a single integer number""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
data_rank = len(data_shape) |
|
|
|
if data_rank == 0: |
|
|
|
return const_utils.raise_type_error("When tensor is indexed by an integer, the dimension of the tensor " |
|
|
|
"cannot be 0.") |
|
|
|
transformed_number = const_utils.check_and_transform_int_index(number, data_shape[0], const_utils.TENSOR_GETITEM) |
|
|
|
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number) |
|
|
|
shrink_axis_mask = 1 |
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) |
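
# Illustrative sketch (assumed shapes): for data of shape (3, 4), data[1] (or data[-2]) maps to
# a strided slice over row 1 with shrink_axis_mask=1, so the leading axis is removed and the
# result has shape (4,).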
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_tensor(data, tensor_index): |
|
|
|
"""Tensor getitem by a single tensor""" |
|
|
|
index_type = F.dtype(tensor_index) |
|
|
|
const_utils.check_index_type_valid(index_type, mstype.int_type, const_utils.TENSOR_GETITEM) |
|
|
|
tensor_index = F.cast(tensor_index, mstype.int64) |
|
|
|
return F.gather(data, tensor_index, 0) |
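
# Illustrative sketch (assumed shapes): for data of shape (3, 4) and tensor_index
# Tensor([2, 0]) with an int dtype, the index is cast to int64 and gathered along axis 0,
# giving shape (2, 4).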
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_list(data, list_index): |
|
|
|
"""Tensor getitem by list of int and bool""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
const_utils.check_sequence_index_type(list_index, const_utils.TENSOR_GETITEM) |
|
|
|
sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM) |
|
|
|
tensor_index = F.tuple_to_array(sub_tuple_index) |
|
|
|
tensor_index = F.cast(tensor_index, mstype.int64) |
|
|
|
return F.gather(data, tensor_index, 0) |
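
# Illustrative sketch (assumed shapes): for data of shape (3, 4), data[[0, 2]] converts the
# list to an int64 index tensor and gathers along axis 0, giving shape (2, 4).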
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_tuple(data, tuple_index): |
|
|
|
"""Tensor getitem by tuple of various types with None""" |
|
|
|
op_name = const_utils.TENSOR_GETITEM |
|
|
|
if len(tuple_index) == 1: |
|
|
|
return data[tuple_index[0]] |
|
|
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) |
|
|
|
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) |
|
|
|
if contain_type == const_utils.ALL_TENSOR: |
|
|
|
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) |
|
|
|
if contain_type == const_utils.ALL_BASIC: |
|
|
|
return _tensor_getitem_by_tuple_slice(data, tuple_index) |
|
|
|
return _tensor_getitem_by_tuple(data, tuple_index) |
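
# Dispatch note (a summary of the branches above, not new behaviour): a tuple made up only of
# tensors goes through gather_nd, a tuple of only basic indices (int/slice) goes through a
# single StridedSlice, and any mixed tuple is converted to tensor indices and gathered with
# gather_nd.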
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): |
|
|
|
"""Tensor getitem by a tuple of tensor.""" |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_GETITEM) |
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple_slice(data, tuple_index): |
|
|
|
"""Tensor getitem by a tuple of slice""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
begin_strides, end_strides, step_strides, shrink_axis_mask = \ |
|
|
|
const_utils.get_stride_info_from_tuple(data_shape, tuple_index) |
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple(data, tuple_index): |
|
|
|
"""Tensor getitem by a tuple of mixed tensor.""" |
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM) |
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): |
|
|
|
"""Generate an indices tensor from a tuple of tensor.""" |
|
|
|
indices = None |
|
|
|
check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) |
|
|
|
if check_index_tensor_number: |
|
|
|
dtype_tuple = hyper_map(F.dtype, tuple_index) |
|
|
|
check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name) |
|
|
|
if check_dtypes: |
|
|
|
shape_tuple = hyper_map(F.shape, tuple_index) |
|
|
|
broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name) |
|
|
|
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) |
|
|
|
indices = pack(broadcast_tensors) |
|
|
|
indexes_types = hyper_map(F.dtype, tuple_index) |
|
|
|
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name) |
|
|
|
tensor_index_shape = hyper_map(F.shape, tuple_index) |
|
|
|
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) |
|
|
|
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) |
|
|
|
indices = pack(broadcast_tensors) |
|
|
|
indices = F.cast(indices, mstype.int64) |
|
|
|
return indices |
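
# Illustrative sketch (assumed shapes): for index tensors of shapes (2, 1) and (1, 3), the
# broadcast shape is (2, 3); packing the broadcast tensors along axis -1 yields an indices
# tensor of shape (2, 3, 2), which is the layout expected by gather_nd.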
|
|
|
|
|
|
|
|
|
|
|
def _generate_indices_from_tuple(data, tuple_index, op_name): |
|
|
|
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
tuple_index_len = len(tuple_index) |
|
|
|
tensor_indexes, slice_indexes = [], [] |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types) |
|
|
|
slice_positions, _, _, int_positions, _, \ |
|
|
|
tensor_positions, sequence_positions = const_utils.get_pos_of_indexes_types(indexes_types, op_name) |
|
|
|
tuple_index_new = () |
|
|
|
tuple_len = len(tuple_index) |
|
|
|
|
|
|
|
for i in range(tuple_len): |
|
|
|
index = tuple_index[i] |
|
|
|
shape = data_shape[i] |
|
|
|
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)): |
|
|
|
if i in int_positions: |
|
|
|
int_index = const_utils.check_and_transform_int_index(index, shape, op_name) |
|
|
|
tensor_index = F.scalar_to_tensor(int_index, mstype.int32) |
|
|
|
int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name) |
|
|
|
tensor_index = F.scalar_to_tensor(int_index, mstype.int64) |
|
|
|
tuple_index_new += (tensor_index,) |
|
|
|
tensor_indexes.append(tensor_index) |
|
|
|
tensor_positions.append(i) |
|
|
|
elif i in sequence_positions: |
|
|
|
sequence_index = const_utils.transform_sequence_index(index, shape, op_name) |
|
|
|
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name) |
|
|
|
tensor_index = F.tuple_to_array(sequence_index) |
|
|
|
tensor_index = F.cast(tensor_index, mstype.int64) |
|
|
|
tuple_index_new += (tensor_index,) |
|
|
|
else: |
|
|
|
tensor_indexes.append(tensor_index) |
|
|
|
tensor_positions.append(i) |
|
|
|
elif i in tensor_positions: |
|
|
|
tensor_index = F.cast(index, mstype.int64) |
|
|
|
tuple_index_new += (tensor_index,) |
|
|
|
tensor_indexes.append(tensor_index) |
|
|
|
elif i in slice_positions: |
|
|
|
slice_indexes.append(index) |
|
|
|
tuple_index_new += (index,) |
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_new) |
|
|
|
tensor_positions, slice_positions, ellipsis_position = \ |
|
|
|
const_utils.separate_mixed_tensors_index(indexes_types, op_name) |
|
|
|
tensor_indexes, slice_indexes = [], [] |
|
|
|
for i in tensor_positions: |
|
|
|
tensor_indexes.append(tuple_index_new[i]) |
|
|
|
for j in slice_positions: |
|
|
|
slice_indexes.append(tuple_index_new[j]) |
|
|
|
|
|
|
|
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) |
|
|
|
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) |
|
|
|
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ |
|
|
|
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape, |
|
|
|
indexes_types, |
|
|
|
tensor_indexes_shapes, |
|
|
|
tensor_indexes_dtypes, |
|
|
|
slice_indexes, |
|
|
|
op_name) |
|
|
|
|
|
|
|
slice_number = 0 |
|
|
|
final_index_tensors = [] |
|
|
|
tuple_index_size = len(tuple_index_new) |
|
|
|
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) |
|
|
|
for i in range(tuple_index_size): |
|
|
|
if i in tensor_positions: |
|
|
|
transform_tensor = _transform_indexing_tensor( |
|
|
|
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]) |
|
|
|
final_index_tensors.append(transform_tensor) |
|
|
|
if i in slice_positions: |
|
|
|
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) |
|
|
|
final_index_tensors.append(slice_tensor) |
|
|
|
slice_number += 1 |
|
|
|
if i == ellipsis_position: |
|
|
|
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors( |
|
|
|
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name) |
|
|
|
for ele in ellipsis_tensors: |
|
|
|
final_index_tensors.append(ele) |
|
|
|
slice_number += ellipsis_occupied_dims |
|
|
|
indices = pack(final_index_tensors) |
|
|
|
return indices |
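
# Sketch of the overall flow (a summary, not new behaviour): ints and sequences in the tuple
# are first converted to int64 index tensors, slices are kept and later expanded into index
# tensors matching the broadcast/final shape, and everything is packed along the last axis
# into one indices tensor for gather_nd / TensorScatterUpdate.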
|
|
|
|
|
|
|
|
|
|
|
def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): |
|
|
|
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
int_positions = const_utils.get_pos_of_int_index(indexes_types) |
|
|
|
tuple_index_new = () |
|
|
|
tuple_len = len(tuple_index) |
|
|
|
for i in range(tuple_len): |
|
|
|
if i in int_positions: |
|
|
|
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + |
|
|
|
data_shape[i], mstype.int32),) |
|
|
|
else: |
|
|
|
tuple_index_new += (tuple_index[i],) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_new) |
|
|
|
tensor_positions, slice_positions, ellipsis_position = \ |
|
|
|
const_utils.separate_mixed_tensors_index(indexes_types, op_name) |
|
|
|
tensor_indexes = [] |
|
|
|
slice_indexes = [] |
|
|
|
for i in tensor_positions: |
|
|
|
tensor_indexes.append(tuple_index_new[i]) |
|
|
|
for j in slice_positions: |
|
|
|
slice_indexes.append(tuple_index_new[j]) |
|
|
|
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) |
|
|
|
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) |
|
|
|
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ |
|
|
|
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape, |
|
|
|
indexes_types, |
|
|
|
tensor_indexes_shapes, |
|
|
|
tensor_indexes_dtypes, |
|
|
|
slice_indexes, |
|
|
|
op_name) |
|
|
|
broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors( |
|
|
|
data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name) |
|
|
|
|
|
|
|
slice_number = 0 |
|
|
|
final_index_tensors = [] |
|
|
|
tuple_index_size = len(tuple_index_new) |
|
|
|
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) |
|
|
|
for i in range(tuple_index_size): |
|
|
|
for i in range(tuple_index_len): |
|
|
|
if i in tensor_positions: |
|
|
|
transform_tensor = _transform_indexing_tensor( |
|
|
|
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]) |
|
|
|
@@ -187,12 +303,7 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): |
|
|
|
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) |
|
|
|
final_index_tensors.append(slice_tensor) |
|
|
|
slice_number += 1 |
|
|
|
if i == ellipsis_position: |
|
|
|
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors( |
|
|
|
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name) |
|
|
|
for ele in ellipsis_tensors: |
|
|
|
final_index_tensors.append(ele) |
|
|
|
slice_number += ellipsis_occupied_dims |
|
|
|
|
|
|
|
indices = pack(final_index_tensors) |
|
|
|
return indices |
|
|
|
|
|
|
|
@@ -239,179 +350,8 @@ def _generate_updates_from_tensor(data, index, value, op_type): |
|
|
|
return value |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem(self, index): |
|
|
|
"""Handle tensor getitem""" |
|
|
|
if isinstance(index, Tensor): |
|
|
|
return tensor_index_by_tensor(self, index) |
|
|
|
if isinstance(index, tuple): |
|
|
|
return tensor_index_by_tuple(self, index) |
|
|
|
if isinstance(index, list): |
|
|
|
return tensor_index_by_list(self, index) |
|
|
|
# bool must be checked before int, because bool is a subclass of int
|
|
|
if isinstance(index, bool): |
|
|
|
return _tensor_index_by_bool(self, index) |
|
|
|
if isinstance(index, int): |
|
|
|
return _tensor_index_by_integer(self, index) |
|
|
|
if isinstance(index, slice): |
|
|
|
return tensor_index_by_slice(self, index) |
|
|
|
if index is None: |
|
|
|
return F.expand_dims(self, 0) |
|
|
|
if index is ...: |
|
|
|
return self |
|
|
|
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, " |
|
|
|
f"got {index} with type {type(index)}.") |
|
|
|
|
|
|
|
|
|
|
|
tensor_operator_registry.register("__getitem__", _tensor_getitem) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): |
|
|
|
"""Tensor getitem by a tuple of tensor.""" |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_GETITEM) |
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple(data, tuple_index): |
|
|
|
"""Tensor getitem by a tuple of mixed tensor.""" |
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM) |
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): |
|
|
|
"""Tensor getitem by a tuple of mixed tensor.""" |
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_GETITEM) |
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_slice(data, slice_index): |
|
|
|
"""Tensor getitem by a single slice""" |
|
|
|
shape = F.shape(data) |
|
|
|
if not shape: |
|
|
|
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor" |
|
|
|
"cannot be 0.") |
|
|
|
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index) |
|
|
|
return F.strided_slice(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_index_by_integer(data, number): |
|
|
|
"""Tensor getitem by a single integer number""" |
|
|
|
shape = F.shape(data) |
|
|
|
if not shape: |
|
|
|
return const_utils.raise_type_error("When tensor is indexed by an integer," |
|
|
|
"the dimension of the tensor cannot be 0.") |
|
|
|
if number >= shape[0]: |
|
|
|
return const_utils.raise_index_error("index {} is out of bounds for axis 0 with size {}".format( |
|
|
|
number, shape[0])) |
|
|
|
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number) |
|
|
|
shrink_axis_mask = 1 |
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_index_by_bool(data, bool_value): |
|
|
|
"""Tensor getitem by a single bool value""" |
|
|
|
if bool_value: |
|
|
|
return F.expand_dims(data, 0) |
|
|
|
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.") |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_number(data, number): |
|
|
|
"""Tensor getitem by a Number which may be integer/float/bool value""" |
|
|
|
number_type = const_utils.check_number_index_type(number) |
|
|
|
if number_type == const_utils.BOOL_: |
|
|
|
return _tensor_index_by_bool(data, number) |
|
|
|
if number_type == const_utils.INT_: |
|
|
|
return _tensor_index_by_integer(data, number) |
|
|
|
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_tensor(data, tensor_index): |
|
|
|
"""Tensor getitem by a single tensor""" |
|
|
|
dtype_valid = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), |
|
|
|
const_utils.TENSOR_GETITEM) |
|
|
|
if dtype_valid: |
|
|
|
return F.gather(data, tensor_index, 0) |
|
|
|
return const_utils.raise_index_error("For 'tensor getitem', " |
|
|
|
"the index tensor data type only support mstype.int32.") |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_index_by_tuple_slice(data, tuple_index): |
|
|
|
"""Tensor getitem by a tuple of slice""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
if len(tuple_index) > len(data_shape): |
|
|
|
const_utils.raise_index_error("When tensor is indexed by a tuple, the length of the tuple cannot " |
|
|
|
"be greater than the dimension of the tensor.") |
|
|
|
begin_strides, end_strides, step_strides, shrink_axis_mask = \ |
|
|
|
const_utils.get_stride_info_from_tuple(data_shape, tuple_index) |
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_list(data, list_index): |
|
|
|
"""Tensor getitem by list of int and bool""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
const_utils.check_list_index_type(list_index) |
|
|
|
list_index = const_utils.transform_list(list_index, data_shape[0]) |
|
|
|
tensor_index = const_utils.convert_list_to_tensor(list_index) |
|
|
|
return F.gather(data, tensor_index, 0) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_index_by_tuple(data, tuple_index): |
|
|
|
"""Tensor getitem by tuple of various types with None""" |
|
|
|
if len(tuple_index) == 1: |
|
|
|
return data[tuple_index[0]] |
|
|
|
tuple_index = _transform_ellipsis_to_slice(tuple_index, data, const_utils.TENSOR_GETITEM) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM) |
|
|
|
if contain_type == const_utils.ALL_TENSOR: |
|
|
|
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) |
|
|
|
if contain_type == const_utils.ALL_BASIC: |
|
|
|
return _tensor_index_by_tuple_slice(data, tuple_index) |
|
|
|
return _tensor_getitem_by_tuple(data, tuple_index) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_setitem(self, index, value): |
|
|
|
"""Handle tensor getitem""" |
|
|
|
if isinstance(index, Tensor): |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_tensor_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_tensor_with_tensor(self, index, value) |
|
|
|
if isinstance(value, tuple): |
|
|
|
return tensor_setitem_by_tensor_with_tuple(self, index, value) |
|
|
|
if isinstance(index, tuple): |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_tuple_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_tuple_with_tensor(self, index, value) |
|
|
|
if isinstance(value, tuple): |
|
|
|
return tensor_setitem_by_tuple_with_tuple(self, index, value) |
|
|
|
if isinstance(index, int): |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_number_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_number_with_tensor(self, index, value) |
|
|
|
if isinstance(index, slice): |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_slice_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_slice_with_tensor(self, index, value) |
|
|
|
if isinstance(index, bool): |
|
|
|
return _tensor_index_by_bool(self, index) |
|
|
|
if index is ...: |
|
|
|
if isinstance(value, (int, float, bool)): |
|
|
|
return tensor_setitem_by_ellipsis_with_number(self, index, value) |
|
|
|
if isinstance(value, Tensor): |
|
|
|
return tensor_setitem_by_ellipsis_with_tensor(self, index, value) |
|
|
|
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ |
|
|
|
and tensor with int32, got {} with type{}".format(index, type(index))) |
|
|
|
|
|
|
|
|
|
|
|
tensor_operator_registry.register("__setitem__", _tensor_setitem) |
|
|
|
|
|
|
|
|
|
|
|
@@ -532,24 +472,21 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): |
|
|
|
if len(tuple_index) == 1: |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
op_name = const_utils.TENSOR_GETITEM |
|
|
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) |
|
|
|
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) |
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
|
|
if index_elements_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
if contain_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
else: |
|
|
|
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
|
tuple_index = const_utils.convert_int_to_slice(tuple_index) |
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
updates = _generate_updates_from_scalar(data, |
|
|
|
indices, |
|
|
|
value, |
|
|
|
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
|
|
|
|
|
|
|
|
|
@@ -597,42 +534,26 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): |
|
|
|
if len(tuple_index) == 1: |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
data_shape = data.shape |
|
|
|
tuple_index_new = () |
|
|
|
for i, index in enumerate(tuple_index): |
|
|
|
if isinstance(index, mstype.Int): |
|
|
|
if index < -data_shape[i] or index >= data_shape[i]: |
|
|
|
const_utils.raise_index_error("The index is out of the data's special dimension range.") |
|
|
|
elif index < 0: |
|
|
|
tuple_index_new += (tuple_index[i]+data_shape[i],) |
|
|
|
else: |
|
|
|
tuple_index_new += (tuple_index[i],) |
|
|
|
else: |
|
|
|
tuple_index_new += (tuple_index[i],) |
|
|
|
op_name = const_utils.TENSOR_GETITEM |
|
|
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) |
|
|
|
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) |
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_new) |
|
|
|
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
|
|
if index_elements_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, |
|
|
|
tuple_index_new, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
if contain_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
else: |
|
|
|
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
|
tuple_index_new = const_utils.convert_int_to_slice(tuple_index_new) |
|
|
|
tuple_index = const_utils.convert_int_to_slice(tuple_index) |
|
|
|
new_shape = () |
|
|
|
for _ in tuple_index_new: |
|
|
|
for _ in tuple_index: |
|
|
|
new_shape += (1,) |
|
|
|
new_shape += value.shape |
|
|
|
value = F.reshape(value, new_shape) |
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data, |
|
|
|
tuple_index_new, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
updates = _generate_updates_from_tensor(data, |
|
|
|
indices, |
|
|
|
value, |
|
|
|
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
|
|
|
|
|
|
|
|
|
@@ -641,24 +562,21 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): |
|
|
|
if len(tuple_index) == 1: |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
op_name = const_utils.TENSOR_GETITEM |
|
|
|
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) |
|
|
|
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) |
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
|
|
|
|
if index_elements_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
if contain_type == const_utils.ALL_TENSOR: |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
else: |
|
|
|
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) |
|
|
|
if int_cnt == const_utils.ALL_INT: |
|
|
|
tuple_index = const_utils.convert_int_to_slice(tuple_index) |
|
|
|
indices = _generate_indices_from_tuple_of_mixed_tensors(data, |
|
|
|
tuple_index, |
|
|
|
const_utils.TENSOR_SETITEM) |
|
|
|
updates = _generate_updates_from_tuple(data, |
|
|
|
indices, |
|
|
|
value, |
|
|
|
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) |
|
|
|
updates = _generate_updates_from_tuple(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
|
|
|
|
|
|
|
|
|
|