|
|
|
@@ -438,8 +438,8 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name): |
|
|
|
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, index) |
|
|
|
final_index_tensors.append(transform_tensor) |
|
|
|
elif i in slice_positions: |
|
|
|
slice_index_tensor = const_utils.convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, |
|
|
|
slice_shapes, fancy_position) |
|
|
|
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, |
|
|
|
slice_shapes, fancy_position) |
|
|
|
final_index_tensors.append(slice_index_tensor) |
|
|
|
slice_cnt += 1 |
|
|
|
|
|
|
|
@@ -512,8 +512,8 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position): |
|
|
|
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, index) |
|
|
|
final_index_tensors.append(transform_tensor) |
|
|
|
elif i in slice_positions: |
|
|
|
slice_index_tensor = const_utils.convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, |
|
|
|
slice_shapes, fancy_position) |
|
|
|
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, |
|
|
|
slice_shapes, fancy_position) |
|
|
|
final_index_tensors.append(slice_index_tensor) |
|
|
|
slice_cnt += 1 |
|
|
|
|
|
|
|
@@ -608,7 +608,7 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value): |
|
|
|
if F.rank(index) < 2: |
|
|
|
index = F.expand_dims(index, 0) |
|
|
|
updates = F.expand_dims(updates, 0) |
|
|
|
return P.TensorScatterUpdate()(data, index, updates) |
|
|
|
return F.tensor_scatter_update(data, index, updates) |
|
|
|
|
|
|
|
|
|
|
|
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value): |
|
|
|
@@ -653,7 +653,7 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value): |
|
|
|
"""Set a tensor item by a tensor with a tuple.""" |
|
|
|
updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR) |
|
|
|
index = F.expand_dims(index, -1) |
|
|
|
return P.TensorScatterUpdate()(data, index, updates) |
|
|
|
return F.tensor_scatter_update(data, index, updates) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_slice_with_number(data, input_slice, value): |
|
|
|
@@ -679,7 +679,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value): |
|
|
|
return data |
|
|
|
value_shape = const_utils.tuple_slice(F.shape(indices), None, -1) |
|
|
|
value = _broadcast(value_shape, value) |
|
|
|
result = P.TensorScatterUpdate()(data, indices, value.astype(F.dtype(data))) |
|
|
|
result = F.tensor_scatter_update(data, indices, value.astype(F.dtype(data))) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
@@ -711,7 +711,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): |
|
|
|
if indices is False: |
|
|
|
return data |
|
|
|
updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) |
|
|
|
return P.TensorScatterUpdate()(data, indices, updates) |
|
|
|
return F.tensor_scatter_update(data, indices, updates) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value): |
|
|
|
@@ -737,7 +737,7 @@ def tensor_setitem_by_number_with_tensor(data, index, value): |
|
|
|
index = const_utils.int_to_index(index, data_shape) |
|
|
|
value_shape = const_utils.tuple_slice(F.shape(index), None, -1) |
|
|
|
value = _broadcast(value_shape, value.astype(F.dtype(data))) |
|
|
|
return P.TensorScatterUpdate()(data, index, value) |
|
|
|
return F.tensor_scatter_update(data, index, value) |
|
|
|
|
|
|
|
|
|
|
|
def tensor_setitem_by_ellipsis_with_number(data, value): |
|
|
|
@@ -955,3 +955,12 @@ def check_indices(dims, indices, mode, allow_negative_index=True): |
|
|
|
return clipped |
|
|
|
|
|
|
|
tensor_operator_registry.register('check_indices', check_indices) |
|
|
|
|
|
|
|
|
|
|
|
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position): |
|
|
|
"""Convert a slice to a tensor.""" |
|
|
|
shape = const_utils.compute_slice_shape(slice_shapes, len(broadcast_shape), slice_cnt, fancy_position) |
|
|
|
array = const_utils.make_tensor(index, mstype.int64).reshape(shape) |
|
|
|
reps = const_utils.compute_multiples(shape, final_shape) |
|
|
|
slice_index_tensor = F.tile(array, reps) |
|
|
|
return slice_index_tensor |