Browse Source

fancy index getitem: None

tags/v1.2.0-rc1
Payne 5 years ago
parent
commit
f56dd425c8
9 changed files with 323 additions and 527 deletions
  1. +1
    -1
      mindspore/nn/layer/basic.py
  2. +1
    -1
      mindspore/ops/composite/array_ops.py
  3. +10
    -6
      mindspore/ops/composite/math_ops.py
  4. +238
    -320
      mindspore/ops/composite/multitype_ops/_compile_utils.py
  5. +57
    -192
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py
  6. +1
    -1
      mindspore/ops/composite/multitype_ops/getitem_impl.py
  7. +2
    -0
      mindspore/ops/composite/multitype_ops/setitem_impl.py
  8. +9
    -3
      mindspore/ops/composite/random_ops.py
  9. +4
    -3
      mindspore/ops/operations/array_ops.py

+ 1
- 1
mindspore/nn/layer/basic.py View File

@@ -84,7 +84,7 @@ class L1Regularizer(Cell):
self.scale = Tensor(scale, dtype=mstype.float32)

def construct(self, weights):
const_utils.check_valid_type(F.dtype(weights), mstype.number_type, 'weights')
const_utils.check_type_valid(F.dtype(weights), mstype.number_type, 'weights')
l1_regularization = self.scale * self.reduce_sum(self.abs(weights))
return l1_regularization



+ 1
- 1
mindspore/ops/composite/array_ops.py View File

@@ -82,7 +82,7 @@ def repeat_elements(x, rep, axis=0):
[3 4 5]
[3 4 5]]
"""
const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x')
const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
rep = _check_positive_int(rep, "rep", "repeat_elements")
axis = _check_is_int(axis, "axis", "repeat_elements")



+ 10
- 6
mindspore/ops/composite/math_ops.py View File

@@ -22,6 +22,8 @@ from mindspore.ops import functional as F
from .. import operations as P

# count_nonzero


@constexpr
def _check_validate_axis(axis, name):
if isinstance(axis, (tuple, list)):
@@ -63,10 +65,10 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
[[3]]
"""

const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x')
const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x')
axis = _check_validate_axis(axis, "count_nonzero")
keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero")
const_utils.check_valid_type(dtype, mstype.number_type + (mstype.bool_,), 'dtype')
const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype')

not_equal = P.NotEqual()
cast = P.Cast()
@@ -79,6 +81,8 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
return nonzero_num

# tensor dot


@constexpr
def _int_to_tuple_conv(axes):
"""
@@ -97,10 +101,10 @@ def _check_axes(axes):
"""
validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
if not isinstance(axes, int):
axes = list(axes) # to avoid immutability issues
axes = list(axes) # to avoid immutability issues
if len(axes) != 2:
raise ValueError("Require two axes inputs, given less")
axes = _int_to_tuple_conv(axes) # convert before length checks
axes = _int_to_tuple_conv(axes) # convert before length checks
if len(axes[0]) != len(axes[1]):
raise ValueError("Axes have to be the same size/length")
if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])):
@@ -113,8 +117,8 @@ def _typecheck_input(x1_type, x2_type):
"""
Check input tensor types to be valid and confirm they are the same type.
"""
const_utils.check_valid_type(x1_type, [mstype.float32, mstype.float16], 'x1')
const_utils.check_valid_type(x2_type, [mstype.float32, mstype.float16], 'x2')
const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1')
const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2')
if x1_type != x2_type:
raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ')



+ 238
- 320
mindspore/ops/composite/multitype_ops/_compile_utils.py View File

@@ -26,6 +26,66 @@ hyper_map = base.HyperMap()
pack = P.Pack(axis=-1)


def _tensor_getitem(self, index):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
return tensor_index_by_tensor(self, index)
if isinstance(index, list):
return tensor_index_by_list(self, index)
if isinstance(index, tuple):
return tensor_index_by_tuple(self, index)
# bool type should be judged before int
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if isinstance(index, int):
return _tensor_index_by_integer(self, index)
if isinstance(index, slice):
return tensor_index_by_slice(self, index)
if index is None:
return F.expand_dims(self, 0)
if index is ...:
return self
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, "
f"list and tuple ,but got {index} with type {type(index)}.")


def _tensor_setitem(self, index, value):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tensor_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tensor_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tensor_with_tuple(self, index, value)
if isinstance(index, tuple):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tuple_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tuple_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tuple_with_tuple(self, index, value)
if isinstance(index, int):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_number_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_number_with_tensor(self, index, value)
if isinstance(index, slice):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_slice_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_slice_with_tensor(self, index, value)
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if index is ...:
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_ellipsis_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_ellipsis_with_tensor(self, index, value)
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\
and tensor with int32, got {} with type{}".format(index, type(index)))


def _broadcast(broadcast_shape, x):
"""Broadcast tensor to the required shape."""
if F.shape(x) == broadcast_shape:
@@ -42,15 +102,21 @@ def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x):
return _broadcast(final_shape, F.reshape(x, new_shape))


def _transform_ellipsis_to_slice(tuple_index, data, op_name):
"""transform ellipsis in the slice to several slice"""
def _transform_ellipsis_to_slice(data, tuple_index, op_name):
"""Check if the tuple index len is longer than the data's dims and transform ellipsis in the indices
to several slice"""
data_shape = F.shape(data)
data_rank = len(data_shape)
indexes_types = hyper_map(F.typeof, tuple_index)
slice_positions, ellipsis_positions, _, int_positions, _, tensor_positions, sequence_positions = \
const_utils.get_pos_of_indexes_types(indexes_types, op_name)

ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) +
len(tensor_positions) + len(sequence_positions))
ellipsis_cnt = len(ellipsis_positions)
if (ellipsis_cnt == 0 and ellipsis_occupy_dims < 0) or (ellipsis_cnt > 0 and ellipsis_occupy_dims < 1):
const_utils.raise_index_error("For the 'getitem Operator', the data_shape should be no less than the "
"tuple index dims")

tuple_index_new = ()
for i, index in enumerate(tuple_index):
@@ -63,122 +129,172 @@ def _transform_ellipsis_to_slice(tuple_index, data, op_name):
return tuple_index_new


def _expand_data_dims_with_none(data, tuple_index, op_name):
"""expand the data's dim with 'None' in tuple_index"""
indexes_types = hyper_map(F.typeof, tuple_index)
none_positions, tuple_index_without_none = (), ()
for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)):
none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none)
tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else(index,)
none_positions += (i,) if none_type_tag else ()

for dim in none_positions:
data = F.expand_dims(data, dim)

return data, tuple_index_without_none


def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor"
"cannot be 0.")
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
return F.strided_slice(data, begin_strides, end_strides, step_strides)


def tensor_index_by_number(data, number):
"""Tensor getitem by a Number which may be integer/float/bool value"""
number_type = const_utils.check_number_index_type(number)
if number_type == const_utils.BOOL_:
return _tensor_index_by_bool(data, number)
if number_type == const_utils.INT_:
return _tensor_index_by_integer(data, number)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")


def _tensor_index_by_bool(data, bool_value):
"""Tensor getitem by a single bool value"""
if bool_value:
return F.expand_dims(data, 0)
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")


def _tensor_index_by_integer(data, number):
"""Tensor getitem by a single integer number"""
data_shape = F.shape(data)
data_rank = len(data_shape)
if data_rank == 0:
return const_utils.raise_type_error("When tensor is indexed by an integer, the dimension of the tensor "
"cannot be 0.")
transformed_number = const_utils.check_and_transform_int_index(number, data_shape[0], const_utils.TENSOR_GETITEM)
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number)
shrink_axis_mask = 1
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


def tensor_index_by_tensor(data, tensor_index):
"""Tensor getitem by a single tensor"""
index_type = F.dtype(tensor_index)
const_utils.check_index_type_valid(index_type, mstype.int_type, const_utils.TENSOR_GETITEM)
tensor_index = F.cast(tensor_index, mstype.int64)
return F.gather(data, tensor_index, 0)


def tensor_index_by_list(data, list_index):
"""Tensor getitem by list of int and bool"""
data_shape = F.shape(data)
const_utils.check_sequence_index_type(list_index, const_utils.TENSOR_GETITEM)
sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM)
tensor_index = F.tuple_to_array(sub_tuple_index)
tensor_index = F.cast(tensor_index, mstype.int64)
return F.gather(data, tensor_index, 0)


def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types with None"""
op_name = const_utils.TENSOR_GETITEM
if len(tuple_index) == 1:
return data[tuple_index[0]]
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)
if contain_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
if contain_type == const_utils.ALL_BASIC:
return _tensor_getitem_by_tuple_slice(data, tuple_index)
return _tensor_getitem_by_tuple(data, tuple_index)


def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result


def _tensor_getitem_by_tuple_slice(data, tuple_index):
"""Tensor getitem by a tuple of slice"""
data_shape = F.shape(data)
begin_strides, end_strides, step_strides, shrink_axis_mask = \
const_utils.get_stride_info_from_tuple(data_shape, tuple_index)
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


def _tensor_getitem_by_tuple(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result


def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple of tensor."""
indices = None
check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name)
if check_index_tensor_number:
dtype_tuple = hyper_map(F.dtype, tuple_index)
check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name)
if check_dtypes:
shape_tuple = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name)
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
indexes_types = hyper_map(F.dtype, tuple_index)
const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name)
tensor_index_shape = hyper_map(F.shape, tuple_index)
broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name)
broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index)
indices = pack(broadcast_tensors)
indices = F.cast(indices, mstype.int64)
return indices


def _generate_indices_from_tuple(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
data_shape = F.shape(data)
tuple_index_len = len(tuple_index)
tensor_indexes, slice_indexes = [], []
indexes_types = hyper_map(F.typeof, tuple_index)
int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types)
slice_positions, _, _, int_positions, _, \
tensor_positions, sequence_positions = const_utils.get_pos_of_indexes_types(indexes_types, op_name)
tuple_index_new = ()
tuple_len = len(tuple_index)

for i in range(tuple_len):
index = tuple_index[i]
shape = data_shape[i]
for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
if i in int_positions:
int_index = const_utils.check_and_transform_int_index(index, shape, op_name)
tensor_index = F.scalar_to_tensor(int_index, mstype.int32)
int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name)
tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
elif i in sequence_positions:
sequence_index = const_utils.transform_sequence_index(index, shape, op_name)
sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name)
tensor_index = F.tuple_to_array(sequence_index)
tensor_index = F.cast(tensor_index, mstype.int64)
tuple_index_new += (tensor_index,)
else:
tensor_indexes.append(tensor_index)
tensor_positions.append(i)
elif i in tensor_positions:
tensor_index = F.cast(index, mstype.int64)
tuple_index_new += (tensor_index,)
tensor_indexes.append(tensor_index)
elif i in slice_positions:
slice_indexes.append(index)
tuple_index_new += (index,)

indexes_types = hyper_map(F.typeof, tuple_index_new)
tensor_positions, slice_positions, ellipsis_position = \
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
tensor_indexes, slice_indexes = [], []
for i in tensor_positions:
tensor_indexes.append(tuple_index_new[i])
for j in slice_positions:
slice_indexes.append(tuple_index_new[j])

tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
indexes_types,
tensor_indexes_shapes,
tensor_indexes_dtypes,
slice_indexes,
op_name)

slice_number = 0
final_index_tensors = []
tuple_index_size = len(tuple_index_new)
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_size):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
final_index_tensors.append(transform_tensor)
if i in slice_positions:
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
if i == ellipsis_position:
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
for ele in ellipsis_tensors:
final_index_tensors.append(ele)
slice_number += ellipsis_occupied_dims
indices = pack(final_index_tensors)
return indices


def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
"""Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor."""
data_shape = F.shape(data)
indexes_types = hyper_map(F.typeof, tuple_index)
int_positions = const_utils.get_pos_of_int_index(indexes_types)
tuple_index_new = ()
tuple_len = len(tuple_index)
for i in range(tuple_len):
if i in int_positions:
tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] +
data_shape[i], mstype.int32),)
else:
tuple_index_new += (tuple_index[i],)
indexes_types = hyper_map(F.typeof, tuple_index_new)
tensor_positions, slice_positions, ellipsis_position = \
const_utils.separate_mixed_tensors_index(indexes_types, op_name)
tensor_indexes = []
slice_indexes = []
for i in tensor_positions:
tensor_indexes.append(tuple_index_new[i])
for j in slice_positions:
slice_indexes.append(tuple_index_new[j])
tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes)
broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \
const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape,
indexes_types,
tensor_indexes_shapes,
tensor_indexes_dtypes,
slice_indexes,
op_name)
broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors(
data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name)

slice_number = 0
final_index_tensors = []
tuple_index_size = len(tuple_index_new)
index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info)
for i in range(tuple_index_size):
for i in range(tuple_index_len):
if i in tensor_positions:
transform_tensor = _transform_indexing_tensor(
broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i])
@@ -187,12 +303,7 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name)
final_index_tensors.append(slice_tensor)
slice_number += 1
if i == ellipsis_position:
ellipsis_tensors = const_utils.convert_ellipsis_to_tensors(
slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name)
for ele in ellipsis_tensors:
final_index_tensors.append(ele)
slice_number += ellipsis_occupied_dims

indices = pack(final_index_tensors)
return indices

@@ -239,179 +350,8 @@ def _generate_updates_from_tensor(data, index, value, op_type):
return value


def _tensor_getitem(self, index):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
return tensor_index_by_tensor(self, index)
if isinstance(index, tuple):
return tensor_index_by_tuple(self, index)
if isinstance(index, list):
return tensor_index_by_list(self, index)
# bool type should be judged before int
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if isinstance(index, int):
return _tensor_index_by_integer(self, index)
if isinstance(index, slice):
return tensor_index_by_slice(self, index)
if index is None:
return F.expand_dims(self, 0)
if index is ...:
return self
raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, "
f"got {index} with type {type(index)}.")


tensor_operator_registry.register("__getitem__", _tensor_getitem)


def _tensor_getitem_by_tuple_of_tensor(data, tuple_index):
"""Tensor getitem by a tuple of tensor."""
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result


def _tensor_getitem_by_tuple(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result


def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index):
"""Tensor getitem by a tuple of mixed tensor."""
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_GETITEM)
result = F.gather_nd(data, indices)
return result


def tensor_index_by_slice(data, slice_index):
"""Tensor getitem by a single slice"""
shape = F.shape(data)
if not shape:
const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor"
"cannot be 0.")
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index)
return F.strided_slice(data, begin_strides, end_strides, step_strides)


def _tensor_index_by_integer(data, number):
"""Tensor getitem by a single integer number"""
shape = F.shape(data)
if not shape:
return const_utils.raise_type_error("When tensor is indexed by an integer,"
"the dimension of the tensor cannot be 0.")
if number >= shape[0]:
return const_utils.raise_index_error("index {} is out of bounds for axis 0 with size {}".format(
number, shape[0]))
begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number)
shrink_axis_mask = 1
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


def _tensor_index_by_bool(data, bool_value):
"""Tensor getitem by a single bool value"""
if bool_value:
return F.expand_dims(data, 0)
return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")


def tensor_index_by_number(data, number):
"""Tensor getitem by a Number which may be integer/float/bool value"""
number_type = const_utils.check_number_index_type(number)
if number_type == const_utils.BOOL_:
return _tensor_index_by_bool(data, number)
if number_type == const_utils.INT_:
return _tensor_index_by_integer(data, number)
return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.")


def tensor_index_by_tensor(data, tensor_index):
"""Tensor getitem by a single tensor"""
dtype_valid = const_utils.check_index_tensor_dtype(F.dtype(tensor_index),
const_utils.TENSOR_GETITEM)
if dtype_valid:
return F.gather(data, tensor_index, 0)
return const_utils.raise_index_error("For 'tensor getitem', "
"the index tensor data type only support mstype.int32.")


def _tensor_index_by_tuple_slice(data, tuple_index):
"""Tensor getitem by a tuple of slice"""
data_shape = F.shape(data)
if len(tuple_index) > len(data_shape):
const_utils.raise_index_error("When tensor is indexed by a tuple, the length of the tuple cannot "
"be greater than the dimension of the tensor.")
begin_strides, end_strides, step_strides, shrink_axis_mask = \
const_utils.get_stride_info_from_tuple(data_shape, tuple_index)
return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides)


def tensor_index_by_list(data, list_index):
"""Tensor getitem by list of int and bool"""
data_shape = F.shape(data)
const_utils.check_list_index_type(list_index)
list_index = const_utils.transform_list(list_index, data_shape[0])
tensor_index = const_utils.convert_list_to_tensor(list_index)
return F.gather(data, tensor_index, 0)


def tensor_index_by_tuple(data, tuple_index):
"""Tensor getitem by tuple of various types with None"""
if len(tuple_index) == 1:
return data[tuple_index[0]]
tuple_index = _transform_ellipsis_to_slice(tuple_index, data, const_utils.TENSOR_GETITEM)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM)
if contain_type == const_utils.ALL_TENSOR:
return _tensor_getitem_by_tuple_of_tensor(data, tuple_index)
if contain_type == const_utils.ALL_BASIC:
return _tensor_index_by_tuple_slice(data, tuple_index)
return _tensor_getitem_by_tuple(data, tuple_index)


def _tensor_setitem(self, index, value):
"""Handle tensor getitem"""
if isinstance(index, Tensor):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tensor_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tensor_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tensor_with_tuple(self, index, value)
if isinstance(index, tuple):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_tuple_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_tuple_with_tensor(self, index, value)
if isinstance(value, tuple):
return tensor_setitem_by_tuple_with_tuple(self, index, value)
if isinstance(index, int):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_number_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_number_with_tensor(self, index, value)
if isinstance(index, slice):
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_slice_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_slice_with_tensor(self, index, value)
if isinstance(index, bool):
return _tensor_index_by_bool(self, index)
if index is ...:
if isinstance(value, (int, float, bool)):
return tensor_setitem_by_ellipsis_with_number(self, index, value)
if isinstance(value, Tensor):
return tensor_setitem_by_ellipsis_with_tensor(self, index, value)
raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\
and tensor with int32, got {} with type{}".format(index, type(index)))


tensor_operator_registry.register("__setitem__", _tensor_setitem)


@@ -532,24 +472,21 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)

indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
tuple_index = const_utils.convert_int_to_slice(tuple_index)
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_scalar(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)


@@ -597,42 +534,26 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
data_shape = data.shape
tuple_index_new = ()
for i, index in enumerate(tuple_index):
if isinstance(index, mstype.Int):
if index < -data_shape[i] or index >= data_shape[i]:
const_utils.raise_index_error("The index is out of the data's special dimension range.")
elif index < 0:
tuple_index_new += (tuple_index[i]+data_shape[i],)
else:
tuple_index_new += (tuple_index[i],)
else:
tuple_index_new += (tuple_index[i],)
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)

indexes_types = hyper_map(F.typeof, tuple_index_new)
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)
indexes_types = hyper_map(F.typeof, tuple_index)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index_new,
const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
tuple_index_new = const_utils.convert_int_to_slice(tuple_index_new)
tuple_index = const_utils.convert_int_to_slice(tuple_index)
new_shape = ()
for _ in tuple_index_new:
for _ in tuple_index:
new_shape += (1,)
new_shape += value.shape
value = F.reshape(value, new_shape)
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index_new,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tensor(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)


@@ -641,24 +562,21 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value):
if len(tuple_index) == 1:
data[tuple_index[0]] = value
return data
op_name = const_utils.TENSOR_GETITEM
tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)
data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name)

indexes_types = hyper_map(F.typeof, tuple_index)
index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM)
contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM)

if index_elements_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data,
tuple_index,
const_utils.TENSOR_SETITEM)
if contain_type == const_utils.ALL_TENSOR:
indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM)
else:
int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM)
if int_cnt == const_utils.ALL_INT:
tuple_index = const_utils.convert_int_to_slice(tuple_index)
indices = _generate_indices_from_tuple_of_mixed_tensors(data,
tuple_index,
const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tuple(data,
indices,
value,
const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM)
updates = _generate_updates_from_tuple(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
return P.TensorScatterUpdate()(data, indices, updates)




+ 57
- 192
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -69,16 +69,8 @@ def check_equal(param1, param2, msg="{},{}"):


@constexpr
def split_tuple_index_for_none(tuple_index):
"""return the none_positions and the tuple_index_without_none whose None index is replaced by slice."""
none_positions, tuple_index_without_none = (), ()
for idx, item in enumerate(tuple_index):
if item is None:
none_positions += (idx,)
tuple_index_without_none += (slice(None, None, None),)
else:
tuple_index_without_none += (item,)
return none_positions, tuple_index_without_none
def make_empty_slice():
return slice(None, None, None)


@constexpr
@@ -139,10 +131,31 @@ def check_valid_dim(dim, name):


@constexpr
def check_valid_type(data_type, value_type, name):
if not data_type in value_type:
raise TypeError(
f"For {name}, valid type include {value_type}, {data_type} is invalid")
def judge_index_type(index_type, target_type):
if index_type == target_type or (isinstance(target_type, (list, tuple)) and index_type in target_type):
return True
return False


@constexpr
def check_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise TypeError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")


@constexpr
def check_index_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise IndexError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")


@constexpr
def check_indexes_types_valid(dtypes, target_type, op_name):
"""Check a tuple of tensor data type."""
for dtype in dtypes:
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise IndexError(f"For '{op_name}', the all index tensor data types should be in {target_type}, "
f"but got {dtype}.")


def slice_expand(input_slices, shape):
@@ -156,9 +169,7 @@ def slice_expand(input_slices, shape):
Outputs:
tuple[list], This is expressed as (begins, ends, strides).
"""
begin = []
end = []
strides = []
begin, end, strides = [], [], []
index = 0
slices = None
# Slice or tuple(Slice...)
@@ -269,19 +280,6 @@ def integer_to_indices(index, shape):
return Tensor(value, dtype=mstype.int32)


@constexpr
def tuple_element_is_slice(indexs):
"""Judges tuple element type."""
if not indexs:
raise IndexError("Tensor's index cannot be empty.")
if isinstance(indexs, tuple):
for _, ele in enumerate(indexs):
if not isinstance(ele, slice):
return False
return True
return False


@constexpr
def tuple_element_is_int(indexs):
"""Judges tuple element type."""
@@ -395,8 +393,7 @@ def generate_broadcast_shape(shapes, op_name):
for i, shape in enumerate(shapes):
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
try:
broadcast_shape = op_utils.get_broadcast_shape(
broadcast_shape, shape, op_name)
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
except ValueError as ex:
raise IndexError(ex)
return tuple(broadcast_shape)
@@ -439,80 +436,17 @@ def compute_new_shape(origin_shape, indexes_shapes_info):


@constexpr
def check_list_index_type(list_index):
def check_sequence_index_type(sequence_index, op_name):
"""check if the item's type of list_index is bool or int"""
if not all([isinstance(index, (int, bool)) for index in list_index]):
raise IndexError(
f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array")
if not all([isinstance(index, (int, bool)) for index in sequence_index]):
raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), "
f"but got {type(index)} in array")


@constexpr
def transform_list(list_index, shape):
"""transfor list_index from int or bool to int"""
bool_count = len(list(filter(lambda index: isinstance(index, bool), list_index)))
int_count = len(list(filter(lambda index: isinstance(index, int), list_index)))-bool_count
if int_count == 0:
if bool_count == shape:
list_index = list(filter(lambda i: list_index[i], range(bool_count)))
else:
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
else:
list_index = [int(index) for index in list_index]
for i, index in enumerate(list_index):
if index < -shape or index >= shape:
raise IndexError(f"The index should in the range [-{shape}, {shape-1}] to fit the corresponding dim "
f"length, but get {index}.")
if index < 0:
index += shape
list_index[i] = index
return list_index


@constexpr
def convert_list_to_tensor(list_index):
"""convert the list_index to tensor_index with mstype.int64 dtype"""
return Tensor(list_index, mstype.int64)


@constexpr
def convert_int_to_slice(tuple_indexes):
tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes)
return tuple_indexes_new


@constexpr
def convert_ellipsis_to_tensors(slice_number,
ellipsis_occupied_dims,
final_shape,
indexes_shapes_info,
op_name):
"""Convert an ellipsis to a list of tensor."""
tensor_list = []
dims_dealt_count = 0
while dims_dealt_count < ellipsis_occupied_dims:
shape = []
slice_count = 0
array = None
for ele in indexes_shapes_info:
if isinstance(ele, list):
if slice_count == slice_number:
array = np.array(ele, np.int32)
shape.append(len(ele))
else:
shape.append(1)
slice_count += 1
if isinstance(ele, tuple):
shape.extend([1] * len(ele))
if array is None:
raise ValueError(
f"For '{op_name}', generate tensors from ellipsis failed.")
array = np.reshape(array, shape)
reps = compute_multiples(shape, final_shape)
tensor = Tensor(np.tile(array, reps))
tensor_list.append(tensor)
slice_number += 1
dims_dealt_count += 1
return tensor_list
def convert_int_to_slice(tuple_index):
tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index)
return tuple_index_new


@constexpr
@@ -567,7 +501,7 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n
f"For '{op_name}', generate tensor from 'slice' failed.")
array = np.reshape(array, shape)
reps = compute_multiples(shape, final_shape)
tensor = Tensor(np.tile(array, reps))
tensor = Tensor(np.tile(array, reps), mstype.int64)
return tensor


@@ -617,21 +551,17 @@ def generate_updates_shape(data_shape, index_shape, op_type):


@constexpr
def check_number_of_index_tensor(data_shape, tuple_len, op_name):
def check_tuple_index_len(data_rank, tuple_index_len, op_name):
"""Check if the number of index tensor exceeds the dimension of the operated tensor."""
if tuple_len <= len(data_shape):
if tuple_index_len <= data_rank:
return True
raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor "
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
raise IndexError(f"For '{op_name}', the number {tuple_index_len} of tuple_index size"
f"is greater than the dimension {data_rank} of the operated tensor.")


@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
indexes_types,
tensor_indexes_shapes,
tensor_indexes_dtypes,
slice_indexes,
op_name):
def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes,
tensor_indexes_dtypes, slice_indexes, op_name):
"""
Generate index info which contain broadcast shape, final shape,
indexes shapes info, ellipsis size from a tuple of mixed tensors.
@@ -642,22 +572,14 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
if indexes_size > data_rank:
raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements "
f"is greater than the dimension {len(data_shape)} of the operated tensor.")
indexes_info = {}
index_tensors_info = {}
ellipsis_num = 0
ellipsis_occupied_dims = 0
tensor_count = 0
slice_count = 0
for i, ele_type in enumerate(indexes_types):
if ellipsis_num == 0:
pos = i
else:
pos = i + ellipsis_occupied_dims - 1
if isinstance(ele_type, mstype.tensor_type):
indexes_info, index_tensors_info = {}, {}
tensor_count, slice_count = 0, 0
for pos, index_type in enumerate(indexes_types):
if isinstance(index_type, mstype.tensor_type):
indexes_info[pos] = tensor_indexes_shapes[tensor_count]
index_tensors_info[pos] = tensor_indexes_shapes[tensor_count]
tensor_count += 1
elif isinstance(ele_type, mstype.slice_type):
elif isinstance(index_type, mstype.slice_type):
slice_obj = slice(slice_indexes[slice_count].start,
slice_indexes[slice_count].stop,
slice_indexes[slice_count].step)
@@ -669,22 +591,12 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape,
slice_indexes[slice_count].stop,
slice_indexes[slice_count].step))
slice_count += 1
elif isinstance(ele_type, mstype.ellipsis_type):
if ellipsis_num != 0:
raise IndexError(
f"For '{op_name}', the index could only contain one ellipsis.")
ellipsis_occupied_dims = data_rank - indexes_size + 1
for j in range(pos, pos + ellipsis_occupied_dims):
# Use list to represent slicing result.
indexes_info[j] = list(range(data_shape[j]))
ellipsis_num += 1
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.")
broadcast_shape, final_shape, indexes_shapes_info = \
_derive_result_shape_info_from_tuple_of_mixed_tensors(
indexes_info, index_tensors_info, op_name)
return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims
f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {index_type}.")
broadcast_shape, final_shape, indexes_shapes_info = _derive_result_shape_info_from_tuple_of_mixed_tensors(
indexes_info, index_tensors_info, op_name)
return broadcast_shape, final_shape, indexes_shapes_info


def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list):
@@ -701,8 +613,7 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
index_tensor_info_value = list(index_tensors_info.values())
broadcast_shape = generate_broadcast_shape(
index_tensor_info_value, op_name)
final_shape = []
indexes_shapes_info = []
final_shape, indexes_shapes_info = [], []
mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous(
index_tensor_info_key)
if mixed_tensors_continuous:
@@ -734,54 +645,6 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te
return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info)


@constexpr
def make_empty_slice():
empty_slice = slice(None, None, None)
return empty_slice


@constexpr
def get_pos_of_int_index(indexes_types):
"""Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis."""
int_positions = []
for i, ele_type in enumerate(indexes_types):
if ele_type in (mstype.int32, mstype.int64):
int_positions.append(i)
return int_positions


@constexpr
def get_pos_of_int_sequence(indexes_types):
"""Get int and sequence index positions from the mixed tensors index."""
int_positions, sequence_positions = [], []
for i, index_type in enumerate(indexes_types):
if isinstance(index_type, mstype.Int):
int_positions.append(i)
elif isinstance(index_type, (tuple, list)):
sequence_positions.append(i)
return int_positions, sequence_positions


@constexpr
def separate_mixed_tensors_index(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
tensor_positions = []
slice_positions = []
ellipsis_position = None
for i, ele_type in enumerate(indexes_types):
if isinstance(ele_type, mstype.tensor_type):
tensor_positions.append(i)
elif isinstance(ele_type, mstype.slice_type):
slice_positions.append(i)
elif isinstance(ele_type, mstype.ellipsis_type):
ellipsis_position = i
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {ele_type}.")

return tensor_positions, slice_positions, ellipsis_position


@constexpr
def get_pos_of_indexes_types(indexes_types, op_name):
"""Separate the position information of tensor and slice and ellipsis from the mixed tensors index."""
@@ -805,6 +668,8 @@ def get_pos_of_indexes_types(indexes_types, op_name):
else:
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
if len(ellipsis_positions) > 1:
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")

return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
tensor_positions, sequence_positions
@@ -906,10 +771,10 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
ellipsis_count = ellipsis_count + 1
if ellipsis_count > 1:
raise IndexError("An index can have only one ellipsis (...)")
ellipsis_range_size = data_rank - (tuple_index_len - 1)
ellipsis_range_size = data_rank - tuple_index_len + 1
begin_strides.extend([0] * (ellipsis_range_size))
end_strides.extend(
[i for i in data_shape[index_count: index_count + (ellipsis_range_size)]])
[shape for shape in data_shape[index_count: index_count + ellipsis_range_size]])
step_strides.extend([1] * (ellipsis_range_size))
index_count = index_count + ellipsis_range_size
else:


+ 1
- 1
mindspore/ops/composite/multitype_ops/getitem_impl.py View File

@@ -162,7 +162,7 @@ def _tensor_getitem_by_number(data, number_index):


@getitem.register("Tensor", "None")
def _tensor_getitem_by_none(data, index):
def _tensor_getitem_by_none(data, none_index):
"""
For none indexing , expand data with one dim.



+ 2
- 0
mindspore/ops/composite/multitype_ops/setitem_impl.py View File

@@ -132,6 +132,7 @@ def _dict_setitem_with_number(data, key, value):
"""
return F.dict_setitem(data, key, value)


@setitem.register("Dictionary", "String", "Tuple")
def _dict_setitem_with_tuple(data, key, value):
"""
@@ -147,6 +148,7 @@ def _dict_setitem_with_tuple(data, key, value):
"""
return F.dict_setitem(data, key, value)


@setitem.register("Tensor", "Tensor", "Tensor")
def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
"""


+ 9
- 3
mindspore/ops/composite/random_ops.py View File

@@ -21,6 +21,7 @@ from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ...common.seed import _get_graph_seed


@constexpr
def _get_seed(op_seed, kernel_name):
"Get the graph-level seed."
@@ -59,14 +60,15 @@ def normal(shape, mean, stddev, seed=None):
"""
mean_dtype = F.dtype(mean)
stddev_dtype = F.dtype(stddev)
const_utils.check_valid_type(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
const_utils.check_valid_type(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
const_utils.check_type_valid(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
const_utils.check_type_valid(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal')
seed1, seed2 = _get_seed(seed, "normal")
stdnormal = P.StandardNormal(seed1, seed2)
random_normal = stdnormal(shape)
value = random_normal * stddev + mean
return value


def laplace(shape, mean, lambda_param, seed=None):
r"""
Generates random numbers according to the Laplace random number distribution.
@@ -112,6 +114,7 @@ def laplace(shape, mean, lambda_param, seed=None):
value = rnd * lambda_param + mean
return value


def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
"""
Generates random numbers according to the Uniform random number distribution.
@@ -159,7 +162,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
"""
minval_dtype = F.dtype(minval)
maxval_dtype = F.dtype(maxval)
const_utils.check_valid_type(dtype, [mstype.int32, mstype.float32], 'uniform')
const_utils.check_type_valid(dtype, [mstype.int32, mstype.float32], 'uniform')
const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform")
const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform")
seed1, seed2 = _get_seed(seed, "uniform")
@@ -172,6 +175,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
value = random_uniform * (maxval - minval) + minval
return value


def gamma(shape, alpha, beta, seed=None):
"""
Generates random numbers according to the Gamma random number distribution.
@@ -205,6 +209,7 @@ def gamma(shape, alpha, beta, seed=None):
value = random_gamma(shape, alpha, beta)
return value


def poisson(shape, mean, seed=None):
"""
Generates random numbers according to the Poisson random number distribution.
@@ -235,6 +240,7 @@ def poisson(shape, mean, seed=None):
value = random_poisson(shape, mean)
return value


def multinomial(inputs, num_sample, replacement=True, seed=None):
r"""
Returns a tensor sampled from the multinomial probability distribution located in the corresponding


+ 4
- 3
mindspore/ops/operations/array_ops.py View File

@@ -724,6 +724,7 @@ class Transpose(PrimitiveWithInfer):
out['max_shape'] = tuple(max_vec)
return out


class Unique(Primitive):
"""
Returns the unique elements of input tensor and also return a tensor containing the index of each value of input
@@ -2787,7 +2788,7 @@ class StridedSlice(PrimitiveWithInfer):
if has_ellipsis:
# When there is ellipsis, handle the second half of the ellipsis split.
ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
j += 1
i += ellipsis_occupied_dims
@@ -3144,7 +3145,7 @@ class TensorScatterUpdate(PrimitiveWithInfer):
return x_shape

def infer_dtype(self, x_dtype, indices_dtype, value_dtype):
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name)
validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name)
args = {"x": x_dtype, "value": value_dtype}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name)
return x_dtype
@@ -3983,7 +3984,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
offset = 1
for i in range(len(self.block_shape)):
padded = out_shape[i + offset] + self.paddings[i][0] + \
self.paddings[i][1]
self.paddings[i][1]
if padded % self.block_shape[i] != 0:
raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
f'block_shape[{i}] {self.block_shape[i]}')


Loading…
Cancel
Save