From e7a3f68d29f9fdacdaa186082a08fce274caea5a Mon Sep 17 00:00:00 2001 From: yepei6 Date: Sun, 28 Feb 2021 15:47:09 +0800 Subject: [PATCH] expand dims of scalar Tensor index --- mindspore/ops/composite/multitype_ops/_compile_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 85925d63b9..f005c8e7bd 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -510,6 +510,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): """Set a tensor item by a int tensor with a scalar.""" + if not F.shape(index): + index = F.expand_dims(index, 0) updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) index = F.expand_dims(index, -1) return P.TensorScatterUpdate()(data, index, updates)