|
|
|
@@ -59,13 +59,15 @@ def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): |
|
|
|
|
|
|
|
def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): |
|
|
|
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" |
|
|
|
data_shape = F.shape(data) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
int_positions = const_utils.get_pos_of_int_index(indexes_types) |
|
|
|
tuple_index_new = () |
|
|
|
tuple_len = len(tuple_index) |
|
|
|
for i in range(tuple_len): |
|
|
|
if i in int_positions: |
|
|
|
tuple_index_new += (F.scalar_to_tensor(tuple_index[i], mstype.int32),) |
|
|
|
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + \ |
|
|
|
data_shape[i], mstype.int32),) |
|
|
|
else: |
|
|
|
tuple_index_new += (tuple_index[i],) |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index_new) |
|
|
|
@@ -77,7 +79,6 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): |
|
|
|
tensor_indexes.append(tuple_index_new[i]) |
|
|
|
for j in slice_positions: |
|
|
|
slice_indexes.append(tuple_index_new[j]) |
|
|
|
data_shape = F.shape(data) |
|
|
|
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) |
|
|
|
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) |
|
|
|
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ |
|
|
|
|