From 10220b364a5ef61433a56988df4c22dd354a82fd Mon Sep 17 00:00:00 2001 From: yepei6 Date: Wed, 3 Feb 2021 20:21:04 +0800 Subject: [PATCH] add the logical switch for the index witch contain '0' --- .../composite/multitype_ops/_compile_utils.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index a86e3cae37..9d0f42ad85 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -229,15 +229,31 @@ def tensor_index_by_tuple(data, tuple_index): indexes_types = hyper_map(F.typeof, tuple_index) contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) if contain_type == const_utils.ALL_TENSOR: - return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) + return _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name) if contain_type == const_utils.ALL_BASIC: return _tensor_getitem_by_tuple_slice(data, tuple_index) - return _tensor_getitem_by_tuple(data, tuple_index) + return _tensor_getitem_by_tuple(data, tuple_index, op_name) -def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): +def _tensor_getitem_by_tuple_of_tensor(data, tuple_index, op_name): """Tensor getitem by a tuple of tensor.""" - indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_GETITEM) + data_shape = F.shape(data) + tuple_index_len = len(tuple_index) + + indexes_types = hyper_map(F.dtype, tuple_index) + const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name) + tensor_index_shape = hyper_map(F.shape, tuple_index) + broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) + if 0 in broadcast_shape: + res_shape = broadcast_shape + data_shape[tuple_index_len:] + res = const_utils.make_tensor([], data.dtype, res_shape) + return res + + broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) + new_broadcast_tensors = () + for tensor in broadcast_tensors: + new_broadcast_tensors += (F.cast(tensor, mstype.int64),) + indices = pack(new_broadcast_tensors) result = F.gather_nd(data, indices) return result @@ -250,9 +266,9 @@ def _tensor_getitem_by_tuple_slice(data, tuple_index): return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) -def _tensor_getitem_by_tuple(data, tuple_index): +def _tensor_getitem_by_tuple(data, tuple_index, op_name): """Tensor getitem by a tuple of mixed tensor.""" - indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM) + indices = _generate_indices_from_tuple(data, tuple_index, op_name) result = F.gather_nd(data, indices) return result