
fancy index getitem: None

tags/v1.2.0-rc1
Payne · 5 years ago
commit f56dd425c8
9 changed files with 323 additions and 527 deletions
1. mindspore/nn/layer/basic.py (+1, -1)
2. mindspore/ops/composite/array_ops.py (+1, -1)
3. mindspore/ops/composite/math_ops.py (+10, -6)
4. mindspore/ops/composite/multitype_ops/_compile_utils.py (+238, -320)
5. mindspore/ops/composite/multitype_ops/_constexpr_utils.py (+57, -192)
6. mindspore/ops/composite/multitype_ops/getitem_impl.py (+1, -1)
7. mindspore/ops/composite/multitype_ops/setitem_impl.py (+2, -0)
8. mindspore/ops/composite/random_ops.py (+9, -3)
9. mindspore/ops/operations/array_ops.py (+4, -3)
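Taken together, the commit reworks tensor fancy indexing: Tensor.__getitem__/__setitem__ now dispatch on None, bool, int, slice, list, tuple, ellipsis and Tensor indices; '...' is expanded into explicit slices up front; the two mixed-tuple index generators are merged into _generate_indices_from_tuple; index tensors are uniformly cast to int64 (with TensorScatterUpdate relaxed to accept int64 indices); and the misnamed helper check_valid_type is renamed to check_type_valid. A minimal sketch of the None handling this enables (hypothetical shapes, NumPy-style semantics):

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.arange(24).reshape(2, 3, 4).astype(np.float32))

    y = x[None]          # None alone inserts a new axis: (2, 3, 4) -> (1, 2, 3, 4)
    z = x[0, None, 1:3]  # None inside a tuple index, next to an int and a slice;
                         # the data is expanded before slicing, result shape (1, 2, 4)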

mindspore/nn/layer/basic.py (+1, -1)

@@ -84,7 +84,7 @@ class L1Regularizer(Cell):
         self.scale = Tensor(scale, dtype=mstype.float32)

     def construct(self, weights):
-        const_utils.check_valid_type(F.dtype(weights), mstype.number_type, 'weights')
+        const_utils.check_type_valid(F.dtype(weights), mstype.number_type, 'weights')
         l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
         return l1_regularization




mindspore/ops/composite/array_ops.py (+1, -1)

@@ -82,7 +82,7 @@ def repeat_elements(x, rep, axis=0):
         [3 4 5]
         [3 4 5]]
     """
-    const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x')
+    const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
     rep = _check_positive_int(rep, "rep", "repeat_elements")
     axis = _check_is_int(axis, "axis", "repeat_elements")




mindspore/ops/composite/math_ops.py (+10, -6)

@@ -22,6 +22,8 @@ from mindspore.ops import functional as F
 from .. import operations as P

 # count_nonzero
+
+
 @constexpr
 def _check_validate_axis(axis, name):
     if isinstance(axis, (tuple, list)):
@@ -63,10 +65,10 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
         [[3]]
     """

-    const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x')
+    const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
     axis = _check_validate_axis(axis, "count_nonzero")
     keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
-    const_utils.check_valid_type(dtype, mstype.number_type + (mstype.bool_,), 'dtype')
+    const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')

     not_equal = P.NotEqual()
     cast = P.Cast()
@@ -79,6 +81,8 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
     return nonzero_num

 # tensor dot
+
+
 @constexpr
 def _int_to_tuple_conv(axes):
     """
@@ -97,10 +101,10 @@ def _check_axes(axes):
     """
     validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
     if not isinstance(axes, int):
-        axes = list(axes) # to avoid immutability issues
+        axes = list(axes)  # to avoid immutability issues
         if len(axes) != 2:
             raise ValueError("Require two axes inputs, given less")
-        axes = _int_to_tuple_conv(axes) # convert before length checks
+        axes = _int_to_tuple_conv(axes)  # convert before length checks
     if len(axes[0]) != len(axes[1]):
         raise ValueError("Axes have to be the same size/length")
     if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
@@ -113,8 +117,8 @@ def _typecheck_input(x1_type, x2_type):
     """
     Check input tensor types to be valid and confirm they are the same type.
     """
-    const_utils.check_valid_type(x1_type, [mstype.float32, mstype.float16], 'x1')
-    const_utils.check_valid_type(x2_type, [mstype.float32, mstype.float16], 'x2')
+    const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
+    const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
     if x1_type != x2_type:
         raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')




mindspore/ops/composite/multitype_ops/_compile_utils.py (+238, -320)

@@ -26,6 +26,66 @@ hyper_map = base.HyperMap()
 pack = P.Pack(axis=-1)


+def _tensor_getitem(self, index):
+    """Handle tensor getitem"""
+    if isinstance(index, Tensor):
+        return tensor_index_by_tensor(self, index)
+    if isinstance(index, list):
+        return tensor_index_by_list(self, index)
+    if isinstance(index, tuple):
+        return tensor_index_by_tuple(self, index)
+    # bool type should be judged before int
+    if isinstance(index, bool):
+        return _tensor_index_by_bool(self, index)
+    if isinstance(index, int):
+        return _tensor_index_by_integer(self, index)
+    if isinstance(index, slice):
+        return tensor_index_by_slice(self, index)
+    if index is None:
+        return F.expand_dims(self, 0)
+    if index is ...:
+        return self
+    raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
+                     f"list and tuple, but got {index} with type {type(index)}.")
+
+
+def _tensor_setitem(self, index, value):
+    """Handle tensor setitem"""
+    if isinstance(index, Tensor):
+        if isinstance(value, (int, float, bool)):
+            return tensor_setitem_by_tensor_with_number(self, index, value)
+        if isinstance(value, Tensor):
+            return tensor_setitem_by_tensor_with_tensor(self, index, value)
+        if isinstance(value, tuple):
+            return tensor_setitem_by_tensor_with_tuple(self, index, value)
+    if isinstance(index, tuple):
+        if isinstance(value, (int, float, bool)):
+            return tensor_setitem_by_tuple_with_number(self, index, value)
+        if isinstance(value, Tensor):
+            return tensor_setitem_by_tuple_with_tensor(self, index, value)
+        if isinstance(value, tuple):
+            return tensor_setitem_by_tuple_with_tuple(self, index, value)
+    if isinstance(index, int):
+        if isinstance(value, (int, float, bool)):
+            return tensor_setitem_by_number_with_number(self, index, value)
+        if isinstance(value, Tensor):
+            return tensor_setitem_by_number_with_tensor(self, index, value)
+    if isinstance(index, slice):
+        if isinstance(value, (int, float, bool)):
+            return tensor_setitem_by_slice_with_number(self, index, value)
+        if isinstance(value, Tensor):
+            return tensor_setitem_by_slice_with_tensor(self, index, value)
+    if isinstance(index, bool):
+        return _tensor_index_by_bool(self, index)
+    if index is ...:
+        if isinstance(value, (int, float, bool)):
+            return tensor_setitem_by_ellipsis_with_number(self, index, value)
+        if isinstance(value, Tensor):
+            return tensor_setitem_by_ellipsis_with_tensor(self, index, value)
+    raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool "
+                     "and tensor with int32, got {} with type {}".format(index, type(index)))
+
+
 def _broadcast(broadcast_shape, x):
     """Broadcast tensor to the required shape."""
     if F.shape(x) == broadcast_shape:
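Both dispatchers check bool before int (note the "# bool type should be judged before int" comment) because Python's bool is a subclass of int; a plain-Python sketch of why the order matters:

    def classify(index):
        if isinstance(index, bool):   # must run first: isinstance(True, int) is True
            return "bool"
        if isinstance(index, int):
            return "int"
        return type(index).__name__

    assert classify(True) == "bool"   # with the checks swapped this would be "int"
    assert classify(1) == "int"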
@@ -42,15 +102,21 @@ def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
     return _broadcast(final_shape, F.reshape(x, new_shape))


-def _transform_ellipsis_to_slice(tuple_index, data, op_name):
-    """transform ellipsis in the slice to several slice"""
+def _transform_ellipsis_to_slice(data, tuple_index, op_name):
+    """Check if the tuple index is longer than the data's dims and transform any ellipsis in the indices
+    to several slices"""
     data_shape = F.shape(data)
     data_rank = len(data_shape)
     indexes_types = hyper_map(F.typeof, tuple_index)
     slice_positions, ellipsis_positions, _, int_positions, _, tensor_positions, sequence_positions = \
         const_utils.get_pos_of_indexes_types(indexes_types, op_name)

     ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) +
                                         len(tensor_positions) + len(sequence_positions))
+    ellipsis_cnt = len(ellipsis_positions)
+    if (ellipsis_cnt == 0 and ellipsis_occupy_dims < 0) or (ellipsis_cnt > 0 and ellipsis_occupy_dims < 1):
+        const_utils.raise_index_error("For the 'getitem' operator, the data_shape should be no less than the "
+                                      "tuple index dims")

     tuple_index_new = ()
     for i, index in enumerate(tuple_index):
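The new guard can be read off a small example; a plain-Python sketch of the intended expansion (assumed semantics, not the MindSpore code itself):

    data_shape = (2, 3, 4, 5)                    # data_rank = 4
    tuple_index = (1, ..., slice(0, 2))          # one int and one slice occupy 2 dims
    ellipsis_occupy_dims = len(data_shape) - 2   # so the ellipsis stands for 2 dims
    expanded = (1, slice(None), slice(None), slice(0, 2))

With no ellipsis present, a negative ellipsis_occupy_dims means the tuple index has more entries than the data has dimensions, which is exactly the case the new raise_index_error call rejects.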
@@ -63,122 +129,172 @@ def _transform_ellipsis_to_slice(tuple_index, data, op_name):
     return tuple_index_new


+def _expand_data_dims_with_none(data, tuple_index, op_name):
+    """expand the data's dim with 'None' in tuple_index"""
+    indexes_types = hyper_map(F.typeof, tuple_index)
+    none_positions, tuple_index_without_none = (), ()
+    for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
+        none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none)
+        tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else (index,)
+        none_positions += (i,) if none_type_tag else ()
+
+    for dim in none_positions:
+        data = F.expand_dims(data, dim)
+
+    return data, tuple_index_without_none
+
+
+def tensor_index_by_slice(data, slice_index):
+    """Tensor getitem by a single slice"""
+    shape = F.shape(data)
+    if not shape:
+        const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor "
+                                      "cannot be 0.")
+    begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
+    return F.strided_slice(data, begin_strides, end_strides, step_strides)
+
+
+def tensor_index_by_number(data, number):
+    """Tensor getitem by a Number which may be integer/float/bool value"""
+    number_type = const_utils.check_number_index_type(number)
+    if number_type == const_utils.BOOL_:
+        return _tensor_index_by_bool(data, number)
+    if number_type == const_utils.INT_:
+        return _tensor_index_by_integer(data, number)
+    return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")
+
+
+def _tensor_index_by_bool(data, bool_value):
+    """Tensor getitem by a single bool value"""
+    if bool_value:
+        return F.expand_dims(data, 0)
+    return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
+
+
+def _tensor_index_by_integer(data, number):
+    """Tensor getitem by a single integer number"""
+    data_shape = F.shape(data)
+    data_rank = len(data_shape)
+    if data_rank == 0:
+        return const_utils.raise_type_error("When tensor is indexed by an integer, the dimension of the tensor "
+                                            "cannot be 0.")
+    transformed_number = const_utils.check_and_transform_int_index(number, data_shape[0], const_utils.TENSOR_GETITEM)
+    begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number)
+    shrink_axis_mask = 1
+    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
+
+
+def tensor_index_by_tensor(data, tensor_index):
+    """Tensor getitem by a single tensor"""
+    index_type = F.dtype(tensor_index)
+    const_utils.check_index_type_valid(index_type, mstype.int_type, const_utils.TENSOR_GETITEM)
+    tensor_index = F.cast(tensor_index, mstype.int64)
+    return F.gather(data, tensor_index, 0)
+
+
+def tensor_index_by_list(data, list_index):
+    """Tensor getitem by list of int and bool"""
+    data_shape = F.shape(data)
+    const_utils.check_sequence_index_type(list_index, const_utils.TENSOR_GETITEM)
+    sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
+    tensor_index = F.tuple_to_array(sub_tuple_index)
+    tensor_index = F.cast(tensor_index, mstype.int64)
+    return F.gather(data, tensor_index, 0)
+
+
+def tensor_index_by_tuple(data, tuple_index):
+    """Tensor getitem by tuple of various types with None"""
+    op_name = const_utils.TENSOR_GETITEM
+    if len(tuple_index) == 1:
+        return data[tuple_index[0]]
+    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
+    data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
+    indexes_types = hyper_map(F.typeof, tuple_index)
+    contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
+    if contain_type == const_utils.ALL_TENSOR:
+        return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
+    if contain_type == const_utils.ALL_BASIC:
+        return _tensor_getitem_by_tuple_slice(data, tuple_index)
+    return _tensor_getitem_by_tuple(data, tuple_index)
+
+
+def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
+    """Tensor getitem by a tuple of tensor."""
+    indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_GETITEM)
+    result = F.gather_nd(data, indices)
+    return result
+
+
+def _tensor_getitem_by_tuple_slice(data, tuple_index):
+    """Tensor getitem by a tuple of slice"""
+    data_shape = F.shape(data)
+    begin_strides, end_strides, step_strides, shrink_axis_mask = \
+        const_utils.get_stride_info_from_tuple(data_shape, tuple_index)
+    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
+
+
+def _tensor_getitem_by_tuple(data, tuple_index):
+    """Tensor getitem by a tuple of mixed tensor."""
+    indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM)
+    result = F.gather_nd(data, indices)
+    return result
+
+
 def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
     """Generate an indices tensor from a tuple of tensor."""
     indices = None
-    check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
-    if check_index_tensor_number:
-        dtype_tuple = hyper_map(F.dtype, tuple_index)
-        check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name)
-        if check_dtypes:
-            shape_tuple = hyper_map(F.shape, tuple_index)
-            broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
-            broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
-            indices = pack(broadcast_tensors)
+    indexes_types = hyper_map(F.dtype, tuple_index)
+    const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
+    tensor_index_shape = hyper_map(F.shape, tuple_index)
+    broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
+    broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
+    indices = pack(broadcast_tensors)
+    indices = F.cast(indices, mstype.int64)
     return indices

 def _generate_indices_from_tuple(data, tuple_index, op_name):
     """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
     data_shape = F.shape(data)
+    tuple_index_len = len(tuple_index)
+    tensor_indexes, slice_indexes = [], []
     indexes_types = hyper_map(F.typeof, tuple_index)
-    int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types)
+    slice_positions, _, _, int_positions, _, \
+        tensor_positions, sequence_positions = const_utils.get_pos_of_indexes_types(indexes_types, op_name)
     tuple_index_new = ()
-    tuple_len = len(tuple_index)

-    for i in range(tuple_len):
-        index = tuple_index[i]
-        shape = data_shape[i]
+    for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
         if i in int_positions:
-            int_index = const_utils.check_and_transform_int_index(index, shape, op_name)
-            tensor_index = F.scalar_to_tensor(int_index, mstype.int32)
+            int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
+            tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
             tuple_index_new += (tensor_index,)
+            tensor_indexes.append(tensor_index)
+            tensor_positions.append(i)
         elif i in sequence_positions:
-            sequence_index = const_utils.transform_sequence_index(index, shape, op_name)
+            sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
             tensor_index = F.tuple_to_array(sequence_index)
+            tensor_index = F.cast(tensor_index, mstype.int64)
             tuple_index_new += (tensor_index,)
-        else:
+            tensor_indexes.append(tensor_index)
+            tensor_positions.append(i)
+        elif i in tensor_positions:
+            tensor_index = F.cast(index, mstype.int64)
+            tuple_index_new += (tensor_index,)
+            tensor_indexes.append(tensor_index)
+        elif i in slice_positions:
+            slice_indexes.append(index)
             tuple_index_new += (index,)

-    indexes_types = hyper_map(F.typeof, tuple_index_new)
-    tensor_positions, slice_positions, ellipsis_position = \
-        const_utils.separate_mixed_tensors_index(indexes_types, op_name)
-    tensor_indexes, slice_indexes = [], []
-    for i in tensor_positions:
-        tensor_indexes.append(tuple_index_new[i])
-    for j in slice_positions:
-        slice_indexes.append(tuple_index_new[j])
-
     tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
     tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
-    broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
-        const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
-                                                                    indexes_types,
-                                                                    tensor_indexes_shapes,
-                                                                    tensor_indexes_dtypes,
-                                                                    slice_indexes,
-                                                                    op_name)
-
-    slice_number = 0
-    final_index_tensors = []
-    tuple_index_size = len(tuple_index_new)
-    index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
-    for i in range(tuple_index_size):
-        if i in tensor_positions:
-            transform_tensor = _transform_indexing_tensor(
-                broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
-            final_index_tensors.append(transform_tensor)
-        if i in slice_positions:
-            slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
-            final_index_tensors.append(slice_tensor)
-            slice_number += 1
-        if i == ellipsis_position:
-            ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
-                slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
-            for ele in ellipsis_tensors:
-                final_index_tensors.append(ele)
-            slice_number += ellipsis_occupied_dims
-    indices = pack(final_index_tensors)
-    return indices
-
-
-def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
-    """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
-    data_shape = F.shape(data)
-    indexes_types = hyper_map(F.typeof, tuple_index)
-    int_positions = const_utils.get_pos_of_int_index(indexes_types)
-    tuple_index_new = ()
-    tuple_len = len(tuple_index)
-    for i in range(tuple_len):
-        if i in int_positions:
-            tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] +
-                                                   data_shape[i], mstype.int32),)
-        else:
-            tuple_index_new += (tuple_index[i],)
     indexes_types = hyper_map(F.typeof, tuple_index_new)
-    tensor_positions, slice_positions, ellipsis_position = \
-        const_utils.separate_mixed_tensors_index(indexes_types, op_name)
-    tensor_indexes = []
-    slice_indexes = []
-    for i in tensor_positions:
-        tensor_indexes.append(tuple_index_new[i])
-    for j in slice_positions:
-        slice_indexes.append(tuple_index_new[j])
-    tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
-    tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
-    broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
-        const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
-                                                                    indexes_types,
-                                                                    tensor_indexes_shapes,
-                                                                    tensor_indexes_dtypes,
-                                                                    slice_indexes,
-                                                                    op_name)
+    broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
+        data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)

     slice_number = 0
     final_index_tensors = []
-    tuple_index_size = len(tuple_index_new)
     index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
-    for i in range(tuple_index_size):
+    for i in range(tuple_index_len):
         if i in tensor_positions:
             transform_tensor = _transform_indexing_tensor(
                 broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
@@ -187,12 +303,7 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
             slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
             final_index_tensors.append(slice_tensor)
             slice_number += 1
-        if i == ellipsis_position:
-            ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
-                slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
-            for ele in ellipsis_tensors:
-                final_index_tensors.append(ele)
-            slice_number += ellipsis_occupied_dims

     indices = pack(final_index_tensors)
     return indices
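What the merged _generate_indices_from_tuple ultimately builds is a stacked index tensor for F.gather_nd; a NumPy analogue of the final gather (a sketch, not the MindSpore implementation):

    import numpy as np

    data = np.arange(12).reshape(3, 4)
    rows = np.array([0, 2])
    cols = np.array([1, 3])
    indices = np.stack([rows, cols], axis=-1)   # like pack = P.Pack(axis=-1); shape (2, 2)
    out = data[rows, cols]                      # what gather_nd(data, indices) returns: [1, 11]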


@@ -239,179 +350,8 @@ def _generate_updates_from_tensor(data, index, value, op_type):
     return value


-def _tensor_getitem(self, index):
-    """Handle tensor getitem"""
-    if isinstance(index, Tensor):
-        return tensor_index_by_tensor(self, index)
-    if isinstance(index, tuple):
-        return tensor_index_by_tuple(self, index)
-    if isinstance(index, list):
-        return tensor_index_by_list(self, index)
-    # bool type should be judged before int
-    if isinstance(index, bool):
-        return _tensor_index_by_bool(self, index)
-    if isinstance(index, int):
-        return _tensor_index_by_integer(self, index)
-    if isinstance(index, slice):
-        return tensor_index_by_slice(self, index)
-    if index is None:
-        return F.expand_dims(self, 0)
-    if index is ...:
-        return self
-    raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, "
-                     f"got {index} with type {type(index)}.")
-
-
 tensor_operator_registry.register("__getitem__", _tensor_getitem)


-def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
-    """Tensor getitem by a tuple of tensor."""
-    indices = _generate_indices_from_tuple_of_tensor(data,
-                                                     tuple_index,
-                                                     const_utils.TENSOR_GETITEM)
-    result = F.gather_nd(data, indices)
-    return result
-
-
-def _tensor_getitem_by_tuple(data, tuple_index):
-    """Tensor getitem by a tuple of mixed tensor."""
-    indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM)
-    result = F.gather_nd(data, indices)
-    return result
-
-
-def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
-    """Tensor getitem by a tuple of mixed tensor."""
-    indices = _generate_indices_from_tuple_of_mixed_tensors(data,
-                                                            tuple_index,
-                                                            const_utils.TENSOR_GETITEM)
-    result = F.gather_nd(data, indices)
-    return result
-
-
-def tensor_index_by_slice(data, slice_index):
-    """Tensor getitem by a single slice"""
-    shape = F.shape(data)
-    if not shape:
-        const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor "
-                                      "cannot be 0.")
-    begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
-    return F.strided_slice(data, begin_strides, end_strides, step_strides)
-
-
-def _tensor_index_by_integer(data, number):
-    """Tensor getitem by a single integer number"""
-    shape = F.shape(data)
-    if not shape:
-        return const_utils.raise_type_error("When tensor is indexed by an integer, "
-                                            "the dimension of the tensor cannot be 0.")
-    if number >= shape[0]:
-        return const_utils.raise_index_error("index {} is out of bounds for axis 0 with size {}".format(
-            number, shape[0]))
-    begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number)
-    shrink_axis_mask = 1
-    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
-
-
-def _tensor_index_by_bool(data, bool_value):
-    """Tensor getitem by a single bool value"""
-    if bool_value:
-        return F.expand_dims(data, 0)
-    return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
-
-
-def tensor_index_by_number(data, number):
-    """Tensor getitem by a Number which may be integer/float/bool value"""
-    number_type = const_utils.check_number_index_type(number)
-    if number_type == const_utils.BOOL_:
-        return _tensor_index_by_bool(data, number)
-    if number_type == const_utils.INT_:
-        return _tensor_index_by_integer(data, number)
-    return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")
-
-
-def tensor_index_by_tensor(data, tensor_index):
-    """Tensor getitem by a single tensor"""
-    dtype_valid = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
-                                                       const_utils.TENSOR_GETITEM)
-    if dtype_valid:
-        return F.gather(data, tensor_index, 0)
-    return const_utils.raise_index_error("For 'tensor getitem', "
-                                         "the index tensor data type only support mstype.int32.")
-
-
-def _tensor_index_by_tuple_slice(data, tuple_index):
-    """Tensor getitem by a tuple of slice"""
-    data_shape = F.shape(data)
-    if len(tuple_index) > len(data_shape):
-        const_utils.raise_index_error("When tensor is indexed by a tuple, the length of the tuple cannot "
-                                      "be greater than the dimension of the tensor.")
-    begin_strides, end_strides, step_strides, shrink_axis_mask = \
-        const_utils.get_stride_info_from_tuple(data_shape, tuple_index)
-    return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)
-
-
-def tensor_index_by_list(data, list_index):
-    """Tensor getitem by list of int and bool"""
-    data_shape = F.shape(data)
-    const_utils.check_list_index_type(list_index)
-    list_index = const_utils.transform_list(list_index, data_shape[0])
-    tensor_index = const_utils.convert_list_to_tensor(list_index)
-    return F.gather(data, tensor_index, 0)
-
-
-def tensor_index_by_tuple(data, tuple_index):
-    """Tensor getitem by tuple of various types with None"""
-    if len(tuple_index) == 1:
-        return data[tuple_index[0]]
-    tuple_index = _transform_ellipsis_to_slice(tuple_index, data, const_utils.TENSOR_GETITEM)
-    indexes_types = hyper_map(F.typeof, tuple_index)
-    contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM)
-    if contain_type == const_utils.ALL_TENSOR:
-        return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
-    if contain_type == const_utils.ALL_BASIC:
-        return _tensor_index_by_tuple_slice(data, tuple_index)
-    return _tensor_getitem_by_tuple(data, tuple_index)
-
-
-def _tensor_setitem(self, index, value):
-    """Handle tensor setitem"""
-    if isinstance(index, Tensor):
-        if isinstance(value, (int, float, bool)):
-            return tensor_setitem_by_tensor_with_number(self, index, value)
-        if isinstance(value, Tensor):
-            return tensor_setitem_by_tensor_with_tensor(self, index, value)
-        if isinstance(value, tuple):
-            return tensor_setitem_by_tensor_with_tuple(self, index, value)
-    if isinstance(index, tuple):
-        if isinstance(value, (int, float, bool)):
-            return tensor_setitem_by_tuple_with_number(self, index, value)
-        if isinstance(value, Tensor):
-            return tensor_setitem_by_tuple_with_tensor(self, index, value)
-        if isinstance(value, tuple):
-            return tensor_setitem_by_tuple_with_tuple(self, index, value)
-    if isinstance(index, int):
-        if isinstance(value, (int, float, bool)):
-            return tensor_setitem_by_number_with_number(self, index, value)
-        if isinstance(value, Tensor):
-            return tensor_setitem_by_number_with_tensor(self, index, value)
-    if isinstance(index, slice):
-        if isinstance(value, (int, float, bool)):
-            return tensor_setitem_by_slice_with_number(self, index, value)
-        if isinstance(value, Tensor):
-            return tensor_setitem_by_slice_with_tensor(self, index, value)
-    if isinstance(index, bool):
-        return _tensor_index_by_bool(self, index)
-    if index is ...:
-        if isinstance(value, (int, float, bool)):
-            return tensor_setitem_by_ellipsis_with_number(self, index, value)
-        if isinstance(value, Tensor):
-            return tensor_setitem_by_ellipsis_with_tensor(self, index, value)
-    raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool "
-                     "and tensor with int32, got {} with type {}".format(index, type(index)))
-
-
 tensor_operator_registry.register("__setitem__", _tensor_setitem)




@@ -532,24 +472,21 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
     if len(tuple_index) == 1:
         data[tuple_index[0]] = value
         return data
+    op_name = const_utils.TENSOR_GETITEM
+    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
+    data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
+
     indexes_types = hyper_map(F.typeof, tuple_index)
-    index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)
+    contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)

-    if index_elements_type == const_utils.ALL_TENSOR:
-        indices = _generate_indices_from_tuple_of_tensor(data,
-                                                         tuple_index,
-                                                         const_utils.TENSOR_SETITEM)
+    if contain_type == const_utils.ALL_TENSOR:
+        indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
     else:
         int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
         if int_cnt == const_utils.ALL_INT:
             tuple_index = const_utils.convert_int_to_slice(tuple_index)
-        indices = _generate_indices_from_tuple_of_mixed_tensors(data,
-                                                                tuple_index,
-                                                                const_utils.TENSOR_SETITEM)
-    updates = _generate_updates_from_scalar(data,
-                                            indices,
-                                            value,
-                                            const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
+        indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
+    updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
     return P.TensorScatterUpdate()(data, indices, updates)
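The scatter path shared by the tuple-setitem variants can be mimicked in NumPy; a sketch of what P.TensorScatterUpdate()(data, indices, updates) computes, with hypothetical values:

    import numpy as np

    data = np.zeros((3, 4), np.float32)
    indices = np.array([[0, 1], [2, 3]])        # one (row, col) pair per update
    updates = np.array([5.0, 7.0], np.float32)
    for idx, upd in zip(indices, updates):
        data[tuple(idx)] = upd                  # data[0, 1] = 5.0; data[2, 3] = 7.0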




@@ -597,42 +534,26 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
     if len(tuple_index) == 1:
         data[tuple_index[0]] = value
         return data
-    data_shape = data.shape
-    tuple_index_new = ()
-    for i, index in enumerate(tuple_index):
-        if isinstance(index, mstype.Int):
-            if index < -data_shape[i] or index >= data_shape[i]:
-                const_utils.raise_index_error("The index is out of the data's special dimension range.")
-            elif index < 0:
-                tuple_index_new += (tuple_index[i]+data_shape[i],)
-            else:
-                tuple_index_new += (tuple_index[i],)
-        else:
-            tuple_index_new += (tuple_index[i],)
+    op_name = const_utils.TENSOR_GETITEM
+    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
+    data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)

-    indexes_types = hyper_map(F.typeof, tuple_index_new)
-    index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)
+    indexes_types = hyper_map(F.typeof, tuple_index)
+    contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)

-    if index_elements_type == const_utils.ALL_TENSOR:
-        indices = _generate_indices_from_tuple_of_tensor(data,
-                                                         tuple_index_new,
-                                                         const_utils.TENSOR_SETITEM)
+    if contain_type == const_utils.ALL_TENSOR:
+        indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
     else:
         int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
         if int_cnt == const_utils.ALL_INT:
-            tuple_index_new = const_utils.convert_int_to_slice(tuple_index_new)
+            tuple_index = const_utils.convert_int_to_slice(tuple_index)
             new_shape = ()
-            for _ in tuple_index_new:
+            for _ in tuple_index:
                 new_shape += (1,)
             new_shape += value.shape
             value = F.reshape(value, new_shape)
-        indices = _generate_indices_from_tuple_of_mixed_tensors(data,
-                                                                tuple_index_new,
-                                                                const_utils.TENSOR_SETITEM)
-    updates = _generate_updates_from_tensor(data,
-                                            indices,
-                                            value,
-                                            const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
+        indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
+    updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
     return P.TensorScatterUpdate()(data, indices, updates)




@@ -641,24 +562,21 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
     if len(tuple_index) == 1:
         data[tuple_index[0]] = value
         return data
+    op_name = const_utils.TENSOR_GETITEM
+    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
+    data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
+
     indexes_types = hyper_map(F.typeof, tuple_index)
-    index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)
+    contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)

-    if index_elements_type == const_utils.ALL_TENSOR:
-        indices = _generate_indices_from_tuple_of_tensor(data,
-                                                         tuple_index,
-                                                         const_utils.TENSOR_SETITEM)
+    if contain_type == const_utils.ALL_TENSOR:
+        indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
     else:
         int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
         if int_cnt == const_utils.ALL_INT:
             tuple_index = const_utils.convert_int_to_slice(tuple_index)
-        indices = _generate_indices_from_tuple_of_mixed_tensors(data,
-                                                                tuple_index,
-                                                                const_utils.TENSOR_SETITEM)
-    updates = _generate_updates_from_tuple(data,
-                                           indices,
-                                           value,
-                                           const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
+        indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
+    updates = _generate_updates_from_tuple(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
     return P.TensorScatterUpdate()(data, indices, updates)






mindspore/ops/composite/multitype_ops/_constexpr_utils.py (+57, -192)

@@ -69,16 +69,8 @@ def check_equal(param1, param2, msg="{},{}"):


 @constexpr
-def split_tuple_index_for_none(tuple_index):
-    """return the none_positions and the tuple_index_without_none whose None index is replaced by slice."""
-    none_positions, tuple_index_without_none = (), ()
-    for idx, item in enumerate(tuple_index):
-        if item is None:
-            none_positions += (idx,)
-            tuple_index_without_none += (slice(None, None, None),)
-        else:
-            tuple_index_without_none += (item,)
-    return none_positions, tuple_index_without_none
+def make_empty_slice():
+    return slice(None, None, None)


 @constexpr
@@ -139,10 +131,31 @@ def check_valid_dim(dim, name):


 @constexpr
-def check_valid_type(data_type, value_type, name):
-    if not data_type in value_type:
-        raise TypeError(
-            f"For {name}, valid type include {value_type}, {data_type} is invalid")
+def judge_index_type(index_type, target_type):
+    if index_type == target_type or (isinstance(target_type, (list, tuple)) and index_type in target_type):
+        return True
+    return False
+
+
+@constexpr
+def check_type_valid(dtype, target_type, op_name):
+    if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
+        raise TypeError(f"The '{op_name}' doesn't support {dtype} and expects to receive {target_type}.")
+
+
+@constexpr
+def check_index_type_valid(dtype, target_type, op_name):
+    if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
+        raise IndexError(f"The '{op_name}' doesn't support {dtype} and expects to receive {target_type}.")
+
+
+@constexpr
+def check_indexes_types_valid(dtypes, target_type, op_name):
+    """Check a tuple of tensor data type."""
+    for dtype in dtypes:
+        if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
+            raise IndexError(f"For '{op_name}', all the index tensor data types should be in {target_type}, "
+                             f"but got {dtype}.")


 def slice_expand(input_slices, shape):
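A pure-Python restatement of the new validator's behavior with a tuple target, the common case in this commit (a sketch, not the @constexpr version above):

    def check_type_valid(dtype, target_type, op_name):
        if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
            raise TypeError(f"The '{op_name}' doesn't support {dtype} and expects to receive {target_type}.")

    check_type_valid("float32", ("float16", "float32"), "uniform")    # passes silently
    # check_type_valid("float64", ("float16", "float32"), "uniform")  # raises TypeError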
@@ -156,9 +169,7 @@ def slice_expand(input_slices, shape):
     Outputs:
         tuple[list], This is expressed as (begins, ends, strides).
     """
-    begin = []
-    end = []
-    strides = []
+    begin, end, strides = [], [], []
     index = 0
     slices = None
     # Slice or tuple(Slice...)
@@ -269,19 +280,6 @@ def integer_to_indices(index, shape):
     return Tensor(value, dtype=mstype.int32)


-@constexpr
-def tuple_element_is_slice(indexs):
-    """Judges tuple element type."""
-    if not indexs:
-        raise IndexError("Tensor's index cannot be empty.")
-    if isinstance(indexs, tuple):
-        for _, ele in enumerate(indexs):
-            if not isinstance(ele, slice):
-                return False
-        return True
-    return False
-
-
 @constexpr
 def tuple_element_is_int(indexs):
     """Judges tuple element type."""
@@ -395,8 +393,7 @@ def generate_broadcast_shape(shapes, op_name):
     for i, shape in enumerate(shapes):
         logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
         try:
-            broadcast_shape = op_utils.get_broadcast_shape(
-                broadcast_shape, shape, op_name)
+            broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
         except ValueError as ex:
             raise IndexError(ex)
     return tuple(broadcast_shape)
@@ -439,80 +436,17 @@ def compute_new_shape(origin_shape, indexes_shapes_info):


 @constexpr
-def check_list_index_type(list_index):
+def check_sequence_index_type(sequence_index, op_name):
     """check if the item's type of list_index is bool or int"""
-    if not all([isinstance(index, (int, bool)) for index in list_index]):
-        raise IndexError(
-            f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array")
+    if not all([isinstance(index, (int, bool)) for index in sequence_index]):
+        raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), "
+                         f"but got {type(index)} in array")


 @constexpr
-def transform_list(list_index, shape):
-    """transfor list_index from int or bool to int"""
-    bool_count = len(list(filter(lambda index: isinstance(index, bool), list_index)))
-    int_count = len(list(filter(lambda index: isinstance(index, int), list_index)))-bool_count
-    if int_count == 0:
-        if bool_count == shape:
-            list_index = list(filter(lambda i: list_index[i], range(bool_count)))
-        else:
-            raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
-    else:
-        list_index = [int(index) for index in list_index]
-        for i, index in enumerate(list_index):
-            if index < -shape or index >= shape:
-                raise IndexError(f"The index should in the range [-{shape}, {shape-1}] to fit the corresponding dim "
-                                 f"length, but get {index}.")
-            if index < 0:
-                index += shape
-            list_index[i] = index
-    return list_index
-
-
-@constexpr
-def convert_list_to_tensor(list_index):
-    """convert the list_index to tensor_index with mstype.int64 dtype"""
-    return Tensor(list_index, mstype.int64)
-
-
-@constexpr
-def convert_int_to_slice(tuple_indexes):
-    tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes)
-    return tuple_indexes_new
-
-
-@constexpr
-def convert_ellipsis_to_tensors(slice_number,
-                                ellipsis_occupied_dims,
-                                final_shape,
-                                indexes_shapes_info,
-                                op_name):
-    """Convert an ellipsis to a list of tensor."""
-    tensor_list = []
-    dims_dealt_count = 0
-    while dims_dealt_count < ellipsis_occupied_dims:
-        shape = []
-        slice_count = 0
-        array = None
-        for ele in indexes_shapes_info:
-            if isinstance(ele, list):
-                if slice_count == slice_number:
-                    array = np.array(ele, np.int32)
-                    shape.append(len(ele))
-                else:
-                    shape.append(1)
-                slice_count += 1
-            if isinstance(ele, tuple):
-                shape.extend([1] * len(ele))
-        if array is None:
-            raise ValueError(
-                f"For '{op_name}', generate tensors from ellipsis failed.")
-        array = np.reshape(array, shape)
-        reps = compute_multiples(shape, final_shape)
-        tensor = Tensor(np.tile(array, reps))
-        tensor_list.append(tensor)
-        slice_number += 1
-        dims_dealt_count += 1
-    return tensor_list
+def convert_int_to_slice(tuple_index):
+    tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index)
+    return tuple_index_new


 @constexpr
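convert_int_to_slice turns every int i into slice(i, i+1, 1) so that all-int tuple indices can take the strided-slice path without dropping dimensions; for example:

    assert tuple(slice(i, i + 1, 1) for i in (1, 3)) == (slice(1, 2, 1), slice(3, 4, 1))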
@@ -567,7 +501,7 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
             f"For '{op_name}', generate tensor from 'slice' failed.")
     array = np.reshape(array, shape)
     reps = compute_multiples(shape, final_shape)
-    tensor = Tensor(np.tile(array, reps))
+    tensor = Tensor(np.tile(array, reps), mstype.int64)
     return tensor




@@ -617,21 +551,17 @@ def generate_updates_shape(data_shape, index_shape, op_type):


 @constexpr
-def check_number_of_index_tensor(data_shape, tuple_len, op_name):
+def check_tuple_index_len(data_rank, tuple_index_len, op_name):
     """Check if the number of index tensor exceeds the dimension of the operated tensor."""
-    if tuple_len <= len(data_shape):
+    if tuple_index_len <= data_rank:
         return True
-    raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor "
-                     f"is greater than the dimension {len(data_shape)} of the operated tensor.")
+    raise IndexError(f"For '{op_name}', the tuple_index size {tuple_index_len} "
+                     f"is greater than the dimension {data_rank} of the operated tensor.")


 @constexpr
-def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
-                                                    indexes_types,
-                                                    tensor_indexes_shapes,
-                                                    tensor_indexes_dtypes,
-                                                    slice_indexes,
-                                                    op_name):
+def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes,
+                                                    tensor_indexes_dtypes, slice_indexes, op_name):
     """
     Generate index info which contain broadcast shape, final shape,
     indexes shapes info, ellipsis size from a tuple of mixed tensors.
@@ -642,22 +572,14 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
     if indexes_size > data_rank:
         raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements "
                          f"is greater than the dimension {len(data_shape)} of the operated tensor.")
-    indexes_info = {}
-    index_tensors_info = {}
-    ellipsis_num = 0
-    ellipsis_occupied_dims = 0
-    tensor_count = 0
-    slice_count = 0
-    for i, ele_type in enumerate(indexes_types):
-        if ellipsis_num == 0:
-            pos = i
-        else:
-            pos = i + ellipsis_occupied_dims - 1
-        if isinstance(ele_type, mstype.tensor_type):
+    indexes_info, index_tensors_info = {}, {}
+    tensor_count, slice_count = 0, 0
+    for pos, index_type in enumerate(indexes_types):
+        if isinstance(index_type, mstype.tensor_type):
             indexes_info[pos] = tensor_indexes_shapes[tensor_count]
             index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
             tensor_count += 1
-        elif isinstance(ele_type, mstype.slice_type):
+        elif isinstance(index_type, mstype.slice_type):
             slice_obj = slice(slice_indexes[slice_count].start,
                               slice_indexes[slice_count].stop,
                               slice_indexes[slice_count].step)
@@ -669,22 +591,12 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
                                                slice_indexes[slice_count].stop,
                                                slice_indexes[slice_count].step))
             slice_count += 1
-        elif isinstance(ele_type, mstype.ellipsis_type):
-            if ellipsis_num != 0:
-                raise IndexError(
-                    f"For '{op_name}', the index could only contain one ellipsis.")
-            ellipsis_occupied_dims = data_rank - indexes_size + 1
-            for j in range(pos, pos + ellipsis_occupied_dims):
-                # Use list to represent slicing result.
-                indexes_info[j] = list(range(data_shape[j]))
-            ellipsis_num += 1
         else:
             raise IndexError(f"For '{op_name}', the index elements only support "
-                             f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.")
-    broadcast_shape, final_shape, indexes_shapes_info = \
-        _derive_result_shape_info_from_tuple_of_mixed_tensors(
-            indexes_info, index_tensors_info, op_name)
-    return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims
+                             f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {index_type}.")
+    broadcast_shape, final_shape, indexes_shapes_info = _derive_result_shape_info_from_tuple_of_mixed_tensors(
+        indexes_info, index_tensors_info, op_name)
+    return broadcast_shape, final_shape, indexes_shapes_info


 def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list):
@@ -701,8 +613,7 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name):
     index_tensor_info_value = list(index_tensors_info.values())
     broadcast_shape = generate_broadcast_shape(
         index_tensor_info_value, op_name)
-    final_shape = []
-    indexes_shapes_info = []
+    final_shape, indexes_shapes_info = [], []
     mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(
         index_tensor_info_key)
     if mixed_tensors_continuous:
@@ -734,54 +645,6 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_tensors_info, op_name):
     return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)


-@constexpr
-def make_empty_slice():
-    empty_slice = slice(None, None, None)
-    return empty_slice
-
-
-@constexpr
-def get_pos_of_int_index(indexes_types):
-    """Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis."""
-    int_positions = []
-    for i, ele_type in enumerate(indexes_types):
-        if ele_type in (mstype.int32, mstype.int64):
-            int_positions.append(i)
-    return int_positions
-
-
-@constexpr
-def get_pos_of_int_sequence(indexes_types):
-    """Get int and sequence index positions from the mixed tensors index."""
-    int_positions, sequence_positions = [], []
-    for i, index_type in enumerate(indexes_types):
-        if isinstance(index_type, mstype.Int):
-            int_positions.append(i)
-        elif isinstance(index_type, (tuple, list)):
-            sequence_positions.append(i)
-    return int_positions, sequence_positions
-
-
-@constexpr
-def separate_mixed_tensors_index(indexes_types, op_name):
-    """Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
-    tensor_positions = []
-    slice_positions = []
-    ellipsis_position = None
-    for i, ele_type in enumerate(indexes_types):
-        if isinstance(ele_type, mstype.tensor_type):
-            tensor_positions.append(i)
-        elif isinstance(ele_type, mstype.slice_type):
-            slice_positions.append(i)
-        elif isinstance(ele_type, mstype.ellipsis_type):
-            ellipsis_position = i
-        else:
-            raise IndexError(f"For '{op_name}', the index elements only support "
-                             f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {ele_type}.")
-
-    return tensor_positions, slice_positions, ellipsis_position
-
-
 @constexpr
 def get_pos_of_indexes_types(indexes_types, op_name):
     """Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
@@ -805,6 +668,8 @@ def get_pos_of_indexes_types(indexes_types, op_name):
     else:
         raise IndexError(f"For '{op_name}', the index elements only support "
                          f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
+    if len(ellipsis_positions) > 1:
+        raise IndexError(f"For '{op_name}', an index can only have a single ellipsis('...')")

     return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
         tensor_positions, sequence_positions
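The new single-ellipsis guard matches NumPy, which rejects the same thing:

    import numpy as np

    a = np.zeros((2, 2, 2))
    try:
        a[..., ...]
    except IndexError as err:
        print(err)   # an index can only have a single ellipsis ('...')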
@@ -906,10 +771,10 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
             ellipsis_count = ellipsis_count + 1
             if ellipsis_count > 1:
                 raise IndexError("An index can have only one ellipsis (...)")
-            ellipsis_range_size = data_rank - (tuple_index_len - 1)
+            ellipsis_range_size = data_rank - tuple_index_len + 1
             begin_strides.extend([0] * (ellipsis_range_size))
             end_strides.extend(
-                [i for i in data_shape[index_count: index_count + (ellipsis_range_size)]])
+                [shape for shape in data_shape[index_count: index_count + ellipsis_range_size]])
             step_strides.extend([1] * (ellipsis_range_size))
             index_count = index_count + ellipsis_range_size
         else:
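The two ellipsis_range_size expressions are algebraically identical; the rewrite only simplifies the form:

    data_rank, tuple_index_len = 4, 3
    assert data_rank - (tuple_index_len - 1) == data_rank - tuple_index_len + 1 == 2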


mindspore/ops/composite/multitype_ops/getitem_impl.py (+1, -1)

@@ -162,7 +162,7 @@ def _tensor_getitem_by_number(data, number_index):


 @getitem.register("Tensor", "None")
-def _tensor_getitem_by_none(data, index):
+def _tensor_getitem_by_none(data, none_index):
     """
     For none indexing, expand data with one dim.



mindspore/ops/composite/multitype_ops/setitem_impl.py (+2, -0)

@@ -132,6 +132,7 @@ def _dict_setitem_with_number(data, key, value):
     """
     return F.dict_setitem(data, key, value)

+
 @setitem.register("Dictionary", "String", "Tuple")
 def _dict_setitem_with_tuple(data, key, value):
     """
@@ -147,6 +148,7 @@ def _dict_setitem_with_tuple(data, key, value):
     """
     return F.dict_setitem(data, key, value)

+
 @setitem.register("Tensor", "Tensor", "Tensor")
 def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
     """


mindspore/ops/composite/random_ops.py (+9, -3)

@@ -21,6 +21,7 @@ from .multitype_ops import _constexpr_utils as const_utils
 from ...common import dtype as mstype
 from ...common.seed import _get_graph_seed

+
 @constexpr
 def _get_seed(op_seed, kernel_name):
     "Get the graph-level seed."
@@ -59,14 +60,15 @@ def normal(shape, mean, stddev, seed=None):
     """
     mean_dtype = F.dtype(mean)
     stddev_dtype = F.dtype(stddev)
-    const_utils.check_valid_type(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
-    const_utils.check_valid_type(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
+    const_utils.check_type_valid(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
+    const_utils.check_type_valid(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
     seed1, seed2 = _get_seed(seed, "normal")
     stdnormal = P.StandardNormal(seed1, seed2)
     random_normal = stdnormal(shape)
     value = random_normal * stddev + mean
     return value

+
 def laplace(shape, mean, lambda_param, seed=None):
     r"""
     Generates random numbers according to the Laplace random number distribution.
@@ -112,6 +114,7 @@ def laplace(shape, mean, lambda_param, seed=None):
     value = rnd * lambda_param + mean
     return value

+
 def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
     """
     Generates random numbers according to the Uniform random number distribution.
@@ -159,7 +162,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
     """
     minval_dtype = F.dtype(minval)
     maxval_dtype = F.dtype(maxval)
-    const_utils.check_valid_type(dtype, [mstype.int32, mstype.float32], 'uniform')
+    const_utils.check_type_valid(dtype, [mstype.int32, mstype.float32], 'uniform')
     const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
     const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
     seed1, seed2 = _get_seed(seed, "uniform")
@@ -172,6 +175,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
     value = random_uniform * (maxval - minval) + minval
     return value

+
 def gamma(shape, alpha, beta, seed=None):
     """
     Generates random numbers according to the Gamma random number distribution.
@@ -205,6 +209,7 @@ def gamma(shape, alpha, beta, seed=None):
     value = random_gamma(shape, alpha, beta)
     return value

+
 def poisson(shape, mean, seed=None):
     """
     Generates random numbers according to the Poisson random number distribution.
@@ -235,6 +240,7 @@ def poisson(shape, mean, seed=None):
     value = random_poisson(shape, mean)
     return value

+
 def multinomial(inputs, num_sample, replacement=True, seed=None):
     r"""
     Returns a tensor sampled from the multinomial probability distribution located in the corresponding


mindspore/ops/operations/array_ops.py (+4, -3)

@@ -724,6 +724,7 @@ class Transpose(PrimitiveWithInfer):
         out['max_shape'] = tuple(max_vec)
         return out

+
 class Unique(Primitive):
     """
     Returns the unique elements of input tensor and also return a tensor containing the index of each value of input
@@ -2787,7 +2788,7 @@ class StridedSlice(PrimitiveWithInfer):
             if has_ellipsis:
                 # When there is ellipsis, handle the second half of the ellipsis split.
                 ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
-                    len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
+                                         len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
                 ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
                 j += 1
                 i += ellipsis_occupied_dims
@@ -3144,7 +3145,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
         return x_shape

     def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
-        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
+        validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
         args = {"x": x_dtype, "value": value_dtype}
         validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
         return x_dtype
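With the relaxed check, int64 index tensors are now accepted by the primitive; a usage sketch with hypothetical values:

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    x = Tensor(np.zeros((3, 4), np.float32))
    indices = Tensor(np.array([[0, 1], [2, 3]], np.int64))   # int64 was rejected before
    updates = Tensor(np.array([5.0, 7.0], np.float32))
    out = P.TensorScatterUpdate()(x, indices, updates)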
@@ -3983,7 +3984,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
         offset = 1
         for i in range(len(self.block_shape)):
             padded = out_shape[i + offset] + self.paddings[i][0] + \
-                self.paddings[i][1]
+                     self.paddings[i][1]
             if padded % self.block_shape[i] != 0:
                 raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
                                  f'block_shape[{i}] {self.block_shape[i]}')

