|
|
|
@@ -229,15 +229,31 @@ def tensor_index_by_tuple(data, tuple_index): |
|
|
|
indexes_types = hyper_map(F.typeof, tuple_index) |
|
|
|
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) |
|
|
|
if contain_type == const_utils.ALL_TENSOR: |
|
|
|
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) |
|
|
|
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name) |
|
|
|
if contain_type == const_utils.ALL_BASIC: |
|
|
|
return _tensor_getitem_by_tuple_slice(data, tuple_index) |
|
|
|
return _tensor_getitem_by_tuple(data, tuple_index) |
|
|
|
return _tensor_getitem_by_tuple(data, tuple_index, op_name) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): |
|
|
|
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name): |
|
|
|
"""Tensor getitem by a tuple of tensor.""" |
|
|
|
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_GETITEM) |
|
|
|
data_shape = F.shape(data) |
|
|
|
tuple_index_len = len(tuple_index) |
|
|
|
|
|
|
|
indexes_types = hyper_map(F.dtype, tuple_index) |
|
|
|
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name) |
|
|
|
tensor_index_shape = hyper_map(F.shape, tuple_index) |
|
|
|
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) |
|
|
|
if 0 in broadcast_shape: |
|
|
|
res_shape = broadcast_shape + data_shape[tuple_index_len:] |
|
|
|
res = const_utils.make_tensor([], data.dtype, res_shape) |
|
|
|
return res |
|
|
|
|
|
|
|
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) |
|
|
|
new_broadcast_tensors = () |
|
|
|
for tensor in broadcast_tensors: |
|
|
|
new_broadcast_tensors += (F.cast(tensor, mstype.int64),) |
|
|
|
indices = pack(new_broadcast_tensors) |
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
return result |
|
|
|
|
|
|
|
@@ -250,9 +266,9 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index): |
|
|
|
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_getitem_by_tuple(data, tuple_index): |
|
|
|
def _tensor_getitem_by_tuple(data, tuple_index, op_name): |
|
|
|
"""Tensor getitem by a tuple of mixed tensor.""" |
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM) |
|
|
|
indices = _generate_indices_from_tuple(data, tuple_index, op_name) |
|
|
|
result = F.gather_nd(data, indices) |
|
|
|
return result |
|
|
|
|
|
|
|
|