Browse Source

improve tensor setitem

tags/v1.3.0
yanglf1121 4 years ago
parent
commit
d3d38d2caa
3 changed files with 31 additions and 26 deletions
  1. +18
    -9
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  2. +11
    -16
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  3. +2
    -1
      mindspore/ops/functional.py

+ 18
- 9
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -438,8 +438,8 @@ def _tensor_getitem_by_tuple(data, tuple_index, op_name):
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, index)
final_index_tensors.append(transform_tensor)
elif i in slice_positions:
slice_index_tensor = const_utils.convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
slice_shapes, fancy_position)
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
slice_shapes, fancy_position)
final_index_tensors.append(slice_index_tensor)
slice_cnt += 1

@@ -512,8 +512,8 @@ def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape, index)
final_index_tensors.append(transform_tensor)
elif i in slice_positions:
slice_index_tensor = const_utils.convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
slice_shapes, fancy_position)
slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
slice_shapes, fancy_position)
final_index_tensors.append(slice_index_tensor)
slice_cnt += 1

@@ -608,7 +608,7 @@ def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
if F.rank(index) < 2:
index = F.expand_dims(index, 0)
updates = F.expand_dims(updates, 0)
return P.TensorScatterUpdate()(data, index, updates)
return F.tensor_scatter_update(data, index, updates)


def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
@@ -653,7 +653,7 @@ def _tensor_setitem_by_tensor_with_sequence(data, index, value):
"""Set a tensor item by a tensor with a tuple."""
updates = _generate_updates_from_sequence(data, index, value, const_utils.SET_ITEM_BY_ONE_TENSOR)
index = F.expand_dims(index, -1)
return P.TensorScatterUpdate()(data, index, updates)
return F.tensor_scatter_update(data, index, updates)


def tensor_setitem_by_slice_with_number(data, input_slice, value):
@@ -679,7 +679,7 @@ def tensor_setitem_by_slice_with_tensor(data, input_slice, value):
return data
value_shape = const_utils.tuple_slice(F.shape(indices), None, -1)
value = _broadcast(value_shape, value)
result = P.TensorScatterUpdate()(data, indices, value.astype(F.dtype(data)))
result = F.tensor_scatter_update(data, indices, value.astype(F.dtype(data)))
return result


@@ -711,7 +711,7 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
if indices is False:
return data
updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)
return F.tensor_scatter_update(data, indices, updates)


def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
@@ -737,7 +737,7 @@ def tensor_setitem_by_number_with_tensor(data, index, value):
index = const_utils.int_to_index(index, data_shape)
value_shape = const_utils.tuple_slice(F.shape(index), None, -1)
value = _broadcast(value_shape, value.astype(F.dtype(data)))
return P.TensorScatterUpdate()(data, index, value)
return F.tensor_scatter_update(data, index, value)


def tensor_setitem_by_ellipsis_with_number(data, value):
@@ -955,3 +955,12 @@ def check_indices(dims, indices, mode, allow_negative_index=True):
return clipped

tensor_operator_registry.register('check_indices', check_indices)


def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
    """Build the int64 index tensor for one slice and tile it up to ``final_shape``."""
    # Shape the slice indices take before they are tiled out to final_shape.
    slice_shape = const_utils.compute_slice_shape(slice_shapes, len(broadcast_shape), slice_cnt, fancy_position)
    index_tensor = const_utils.make_tensor(index, mstype.int64).reshape(slice_shape)
    # Per-axis repeat counts needed to grow slice_shape into final_shape.
    multiples = const_utils.compute_multiples(slice_shape, final_shape)
    return F.tile(index_tensor, multiples)

+ 11
- 16
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -454,21 +454,6 @@ def compute_multiples(origin_shape, broadcast_shape):
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))


@constexpr
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
    """Turn slice indices into an int64 tensor tiled to ``final_shape``."""
    # Every slice axis has size 1 except the axis owned by this slice.
    slice_dims = [1] * len(slice_shapes)
    slice_dims[slice_cnt] = slice_shapes[slice_cnt]
    # Splice one unit axis per broadcast dimension in at fancy_position.
    broadcast_dims = [1] * len(broadcast_shape)
    target_shape = slice_dims[:fancy_position] + broadcast_dims + slice_dims[fancy_position:]

    index_array = np.reshape(np.array(index, np.int64), target_shape)
    multiples = compute_multiples(target_shape, final_shape)
    return Tensor(np.tile(index_array, multiples), mstype.int64)


@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
"""Convert a scalar to a tensor."""
@@ -492,7 +477,8 @@ def generate_updates_shape(data_shape, index_shape, op_type):
@constexpr
def transform_slice_to_ele_list(slice_index, dim_len):
    """Expand a slice into the explicit list of element indices it selects.

    Args:
        slice_index: Object exposing ``start``, ``stop`` and ``step`` attributes.
        dim_len: Length of the sliced dimension; forwarded to ``normalize_slice``.

    Returns:
        list: ``list(range(start, stop, step))`` built from the normalized bounds.

    Raises:
        IndexError: If the slice selects no elements.
    """
    slice_obj = slice(slice_index.start, slice_index.stop, slice_index.step)
    # The third value returned by normalize_slice is used as the range step,
    # so it is named accordingly. The former list(range(dim_len))[slice_obj]
    # computation was overwritten right away and is dropped.
    start, stop, step = normalize_slice(slice_obj, dim_len)
    slice_ele_list = list(range(start, stop, step))
    if not slice_ele_list:
        raise IndexError(f"An empty slice is not supported, got {slice_obj}")
    return slice_ele_list
@@ -805,3 +791,12 @@ def real_axes(ndim_orig, ndim_out, axes_orig):


check_axis_valid_const = constexpr(validator.check_axis_valid)


@constexpr
def compute_slice_shape(slice_shape, broadcast_shape_len, slice_cnt, fancy_position):
    """Compute the shape of one slice's index tensor before tiling."""
    # Unit size on every slice axis except the one at slice_cnt.
    slice_dims = [1] * len(slice_shape)
    slice_dims[slice_cnt] = slice_shape[slice_cnt]
    # Insert broadcast_shape_len unit axes at fancy_position.
    leading = slice_dims[:fancy_position]
    trailing = slice_dims[fancy_position:]
    return leading + [1] * broadcast_shape_len + trailing

+ 2
- 1
mindspore/ops/functional.py View File

@@ -129,6 +129,7 @@ gather = P.Gather()
gather_d = P.GatherD()
gather_nd = P.GatherNd()
scatter_update = P.ScatterUpdate()
tensor_scatter_update = P.TensorScatterUpdate()
scatter_nd_update = P.ScatterNdUpdate()
stack = P.Stack()

@@ -211,7 +212,7 @@ switch = Primitive('Switch')
switch_layer = Primitive('switch_layer')
# for sum bprop
reduced_shape = Primitive("reduced_shape")
# shape_mul:input mush be shape multiply elemts in tuple(shape)
# shape_mul:input must be shape multiply elements in tuple(shape)
shape_mul = Primitive("shape_mul")
# a primitive to compare between tuple.
stop_gradient = Primitive("stop_gradient")


Loading…
Cancel
Save