|
|
|
@@ -42,6 +42,27 @@ def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): |
|
|
|
return _broadcast(final_shape, F.reshape(x, new_shape)) |
|
|
|
|
|
|
|
|
|
|
|
def _transform_ellipsis_to_slice(tuple_index, data, op_name): |
|
|
|
"""transform ellipsis in the slice to several slice""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
data_rank = len(data_shape) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
slice_positions, ellipsis_positions, _, int_positions, _, tensor_positions, sequence_positions = \ |
|
|
|
const_utils.get_pos_of_indexes_types(indexes_types, op_name) |
|
|
|
ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) + |
|
|
|
len(tensor_positions) + len(sequence_positions)) |
|
|
|
|
|
|
|
tuple_index_new = () |
|
|
|
for i, index in enumerate(tuple_index): |
|
|
|
if i in ellipsis_positions: |
|
|
|
for _ in range(ellipsis_occupy_dims): |
|
|
|
empty_slice = const_utils.make_empty_slice() |
|
|
|
tuple_index_new += (empty_slice,) |
|
|
|
else: |
|
|
|
tuple_index_new += (index,) |
|
|
|
return tuple_index_new |
|
|
|
|
|
|
|
|
|
|
|
def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): |
|
|
|
"""Generate an indices tensor from a tuple of tensor.""" |
|
|
|
indices = None |
|
|
|
@@ -64,6 +85,7 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): |
|
|
|
int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types) |
|
|
|
tuple_index_new = () |
|
|
|
tuple_len = len(tuple_index) |
|
|
|
|
|
|
|
for i in range(tuple_len): |
|
|
|
index = tuple_index[i] |
|
|
|
shape = data_shape[i] |
|
|
|
@@ -77,15 +99,16 @@ def _generate_indices_from_tuple(data, tuple_index, op_name): |
|
|
|
tuple_index_new += (tensor_index,) |
|
|
|
else: |
|
|
|
tuple_index_new += (index,) |
|
|
|
|
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_new) |
|
|
|
tensor_positions, slice_positions, ellipsis_position = \ |
|
|
|
const_utils.separate_mixed_tensors_index(indexes_types, op_name) |
|
|
|
tensor_indexes = [] |
|
|
|
slice_indexes = [] |
|
|
|
tensor_indexes, slice_indexes = [], [] |
|
|
|
for i in tensor_positions: |
|
|
|
tensor_indexes.append(tuple_index_new[i]) |
|
|
|
for j in slice_positions: |
|
|
|
slice_indexes.append(tuple_index_new[j]) |
|
|
|
|
|
|
|
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) |
|
|
|
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) |
|
|
|
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ |
|
|
|
@@ -320,12 +343,12 @@ def tensor_index_by_tensor(data, tensor_index): |
|
|
|
|
|
|
|
def _tensor_index_by_tuple_slice(data, tuple_index): |
|
|
|
"""Tensor getitem by a tuple of slice""" |
|
|
|
shape = F.shape(data) |
|
|
|
if len(tuple_index) > len(shape): |
|
|
|
const_utils.raise_index_error("When tensor is indexed by a tuple, " |
|
|
|
"the length of the tuple cannot be greater than the dimension of the tensor.") |
|
|
|
data_shape = F.shape(data) |
|
|
|
if len(tuple_index) > len(data_shape): |
|
|
|
const_utils.raise_index_error("When tensor is indexed by a tuple, the length of the tuple cannot " |
|
|
|
"be greater than the dimension of the tensor.") |
|
|
|
begin_strides, end_strides, step_strides, shrink_axis_mask = \ |
|
|
|
const_utils.get_stride_info_from_tuple(shape, tuple_index) |
|
|
|
const_utils.get_stride_info_from_tuple(data_shape, tuple_index) |
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
|
|
|
|
@@ -340,16 +363,15 @@ def tensor_index_by_list(data, list_index): |
|
|
|
|
|
|
|
def tensor_index_by_tuple(data, tuple_index): |
|
|
|
"""Tensor getitem by tuple of various types with None""" |
|
|
|
tuple_index_without_none = tuple_index |
|
|
|
if len(tuple_index) == 1: |
|
|
|
return data[tuple_index_without_none[0]] |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_without_none) |
|
|
|
return data[tuple_index[0]] |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM) |
|
|
|
if contain_type == const_utils.ALL_TENSOR: |
|
|
|
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) |
|
|
|
if contain_type == const_utils.ALL_BASIC: |
|
|
|
return _tensor_index_by_tuple_slice(data, tuple_index) |
|
|
|
return _tensor_getitem_by_tuple(data, tuple_index_without_none) |
|
|
|
return _tensor_getitem_by_tuple(data, tuple_index) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_setitem(self, index, value): |
|
|
|
@@ -456,7 +478,7 @@ def tensor_setitem_by_tensor_with_number(data, index, value): |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tensor_with_tuple(data, index, value): |
|
|
|
"""Assigns the tensor by tensor with tuple value.""" |
|
|
|
"""Assigns the tensor by tensor with tuple value.""" |
|
|
|
index_dtype = F.dtype(index) |
|
|
|
check_dtype = const_utils.check_index_tensor_dtype(index_dtype, const_utils.TENSOR_SETITEM) |
|
|
|
result = None |
|
|
|
@@ -505,7 +527,7 @@ def tensor_setitem_by_slice_with_number(data, input_slice, value): |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple_with_number(data, tuple_index, value): |
|
|
|
"""Assigns the tensor by tuple with number value.""" |
|
|
|
"""Assigns the tensor by tuple with number value.""" |
|
|
|
if len(tuple_index) == 1: |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
@@ -570,7 +592,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): |
|
|
|
"""Assigns the tensor by tuple with tensor value.""" |
|
|
|
"""Assigns the tensor by tuple with tensor value.""" |
|
|
|
if len(tuple_index) == 1: |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
@@ -614,7 +636,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): |
|
|
|
"""Assigns the tensor by tuple with tuple of value.""" |
|
|
|
"""Assigns the tensor by tuple with tuple of value.""" |
|
|
|
if len(tuple_index) == 1: |
|
|
|
data[tuple_index[0]] = value |
|
|
|
return data |
|
|
|
@@ -640,28 +662,28 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_number_with_number(data, index, value): |
|
|
|
"""Assigns the tensor by number with number value.""" |
|
|
|
"""Assigns the tensor by number with number value.""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
indices = const_utils.integer_to_indices(index, data_shape) |
|
|
|
return _tensor_indices_number(data, data_shape, index, indices, value) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_number_with_tensor(data, index, value): |
|
|
|
"""Assigns the tensor by number with tensor value.""" |
|
|
|
"""Assigns the tensor by number with tensor value.""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
indices = const_utils.integer_to_indices(index, data_shape) |
|
|
|
return _tensor_indices_tensor(data, data_shape, index, indices, value) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis_with_number(data, index, value): |
|
|
|
"""Assigns the tensor by ellipsis with number value.""" |
|
|
|
"""Assigns the tensor by ellipsis with number value.""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
data_dtype = F.dtype(data) |
|
|
|
return F.fill(data_dtype, data_shape, value) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis_with_tensor(data, index, value): |
|
|
|
"""Assigns the tensor by ellipsis with tensor value.""" |
|
|
|
"""Assigns the tensor by ellipsis with tensor value.""" |
|
|
|
result = None |
|
|
|
data_shape = F.shape(data) |
|
|
|
data_dtype = F.dtype(data) |
|
|
|
|