Browse Source

!12068 getitem with empty list

From: @yepei6
Reviewed-by: @kisnwang,@kingxian
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
a99b3cd8ca
1 changed files with 22 additions and 6 deletions
  1. +22
    -6
      mindspore/ops/composite/multitype_ops/_compile_utils.py

+ 22
- 6
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -229,15 +229,31 @@ def tensor_index_by_tuple(data, tuple_index):
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
if contain_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name)
if contain_type == const_utils.ALL_BASIC:
return _tensor_getitem_by_tuple_slice(data, tuple_index)
return _tensor_getitem_by_tuple(data, tuple_index)
return _tensor_getitem_by_tuple(data, tuple_index, op_name)


def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name):
"""Tensor getitem by a tuple of tensor."""
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_GETITEM)
data_shape = F.shape(data)
tuple_index_len = len(tuple_index)

indexes_types = hyper_map(F.dtype, tuple_index)
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
if 0 in broadcast_shape:
res_shape = broadcast_shape + data_shape[tuple_index_len:]
res = const_utils.make_tensor([], data.dtype, res_shape)
return res

broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
new_broadcast_tensors = ()
for tensor in broadcast_tensors:
new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
indices = pack(new_broadcast_tensors)
result = F.gather_nd(data, indices)
return result

@@ -250,9 +266,9 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index):
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


def _tensor_getitem_by_tuple(data, tuple_index):
def _tensor_getitem_by_tuple(data, tuple_index, op_name):
"""Tensor getitem by a tuple of mixed tensor."""
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM)
indices = _generate_indices_from_tuple(data, tuple_index, op_name)
result = F.gather_nd(data, indices)
return result



Loading…
Cancel
Save