diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index a86e3cae37..337fb7ba8a 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -423,10 +423,9 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): """Set a tensor item by a int tensor with a scalar.""" - updates = _generate_updates_from_scalar(data, index, value, - const_utils.SET_ITEM_BY_ONE_TENSOR) - index = F.expand_dims(index, -1) - return P.TensorScatterUpdate()(data, index, updates) + index = F.expand_dims(index, 0) + updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) + return P.ScatterUpdate()(data, index, updates) def tensor_setitem_by_tensor_with_number(data, index, value):