diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 85925d63b9..f005c8e7bd 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -510,6 +510,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): """Set a tensor item by a int tensor with a scalar.""" + if not F.shape(index): + index = F.expand_dims(index, 0) updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) index = F.expand_dims(index, -1) return P.TensorScatterUpdate()(data, index, updates)