|
|
|
@@ -423,10 +423,9 @@ def _tensor_setitem_by_bool_tensor_with_scalar(data, index, value): |
|
|
|
|
|
|
|
def _tensor_setitem_by_int_tensor_with_scalar(data, index, value): |
|
|
|
"""Set a tensor item by a int tensor with a scalar.""" |
|
|
|
updates = _generate_updates_from_scalar(data, index, value, |
|
|
|
const_utils.SET_ITEM_BY_ONE_TENSOR) |
|
|
|
index = F.expand_dims(index, -1) |
|
|
|
return P.TensorScatterUpdate()(data, index, updates) |
|
|
|
index = F.expand_dims(index, 0) |
|
|
|
updates = _generate_updates_from_scalar(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) |
|
|
|
return P.ScatterUpdate()(data, index, updates) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tensor_with_number(data, index, value): |
|
|
|
|