Browse Source

!12716 setitem debug to support 0d scalar tensor

From: @yepei6
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
5ea2dc6e69
1 changed files with 2 additions and 0 deletions
  1. +2
    -0
      mindspore/ops/composite/multitype_ops/_compile_utils.py

+ 2
- 0
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -510,6 +510,8 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value):

def _tensor_setitem_by_int_tensor_with_scalar(data, index, value):
"""Set a tensor item by a int tensor with a scalar."""
if not F.shape(index):
index = F.expand_dims(index, 0)
updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)


Loading…
Cancel
Save