diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index bea7acde9e..e6ae679b6f 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -84,7 +84,7 @@ class L1Regularizer(Cell): self.scale = Tensor(scale, dtype=mstype.float32) def construct(self, weights): - const_utils.check_valid_type(F.dtype(weights), mstype.number_type, 'weights') + const_utils.check_type_valid(F.dtype(weights), mstype.number_type, 'weights') l1_regularization = self.scale * self.reduce_sum(self.abs(weights)) return l1_regularization diff --git a/mindspore/ops/composite/array_ops.py b/mindspore/ops/composite/array_ops.py index 894abc7777..a64d5faac0 100644 --- a/mindspore/ops/composite/array_ops.py +++ b/mindspore/ops/composite/array_ops.py @@ -82,7 +82,7 @@ def repeat_elements(x, rep, axis=0): [3 4 5] [3 4 5]] """ - const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x') + const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x') rep = _check_positive_int(rep, "rep", "repeat_elements") axis = _check_is_int(axis, "axis", "repeat_elements") diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py index 1173ae12bb..f5586311d2 100644 --- a/mindspore/ops/composite/math_ops.py +++ b/mindspore/ops/composite/math_ops.py @@ -22,6 +22,8 @@ from mindspore.ops import functional as F from .. import operations as P # count_nonzero + + @constexpr def _check_validate_axis(axis, name): if isinstance(axis, (tuple, list)): @@ -63,10 +65,10 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32): [[3]] """ - const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x') + const_utils.check_type_valid(F.dtype(x), mstype.number_type, 'input x') axis = _check_validate_axis(axis, "count_nonzero") keep_dims = _check_validate_keepdims(keep_dims, "count_nonzero") - const_utils.check_valid_type(dtype, mstype.number_type + (mstype.bool_,), 'dtype') + const_utils.check_type_valid(dtype, mstype.number_type + (mstype.bool_,), 'dtype') not_equal = P.NotEqual() cast = P.Cast() @@ -79,6 +81,8 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32): return nonzero_num # tensor dot + + @constexpr def _int_to_tuple_conv(axes): """ @@ -97,10 +101,10 @@ def _check_axes(axes): """ validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot") if not isinstance(axes, int): - axes = list(axes) # to avoid immutability issues + axes = list(axes) # to avoid immutability issues if len(axes) != 2: raise ValueError("Require two axes inputs, given less") - axes = _int_to_tuple_conv(axes) # convert before length checks + axes = _int_to_tuple_conv(axes) # convert before length checks if len(axes[0]) != len(axes[1]): raise ValueError("Axes have to be the same size/length") if len(axes[0]) != len(set(axes[0])) or len(axes[1]) != len(set(axes[1])): @@ -113,8 +117,8 @@ def _typecheck_input(x1_type, x2_type): """ Check input tensor types to be valid and confirm they are the same type. """ - const_utils.check_valid_type(x1_type, [mstype.float32, mstype.float16], 'x1') - const_utils.check_valid_type(x2_type, [mstype.float32, mstype.float16], 'x2') + const_utils.check_type_valid(x1_type, [mstype.float32, mstype.float16], 'x1') + const_utils.check_type_valid(x2_type, [mstype.float32, mstype.float16], 'x2') if x1_type != x2_type: raise TypeError(f'Both Inputs must be the same Type. x1 is \'{x1_type}\' and x2 is \'{x2_type}\' ') diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index a87773265f..051777c85d 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -26,6 +26,66 @@ hyper_map = base.HyperMap() pack = P.Pack(axis=-1) +def _tensor_getitem(self, index): + """Handle tensor getitem""" + if isinstance(index, Tensor): + return tensor_index_by_tensor(self, index) + if isinstance(index, list): + return tensor_index_by_list(self, index) + if isinstance(index, tuple): + return tensor_index_by_tuple(self, index) + # bool type should be judged before int + if isinstance(index, bool): + return _tensor_index_by_bool(self, index) + if isinstance(index, int): + return _tensor_index_by_integer(self, index) + if isinstance(index, slice): + return tensor_index_by_slice(self, index) + if index is None: + return F.expand_dims(self, 0) + if index is ...: + return self + raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool, tensor with int, " + f"list and tuple ,but got {index} with type {type(index)}.") + + +def _tensor_setitem(self, index, value): + """Handle tensor getitem""" + if isinstance(index, Tensor): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_tensor_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_tensor_with_tensor(self, index, value) + if isinstance(value, tuple): + return tensor_setitem_by_tensor_with_tuple(self, index, value) + if isinstance(index, tuple): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_tuple_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_tuple_with_tensor(self, index, value) + if isinstance(value, tuple): + return tensor_setitem_by_tuple_with_tuple(self, index, value) + if isinstance(index, int): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_number_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_number_with_tensor(self, index, value) + if isinstance(index, slice): + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_slice_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_slice_with_tensor(self, index, value) + if isinstance(index, bool): + return _tensor_index_by_bool(self, index) + if index is ...: + if isinstance(value, (int, float, bool)): + return tensor_setitem_by_ellipsis_with_number(self, index, value) + if isinstance(value, Tensor): + return tensor_setitem_by_ellipsis_with_tensor(self, index, value) + raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ + and tensor with int32, got {} with type{}".format(index, type(index))) + + def _broadcast(broadcast_shape, x): """Broadcast tensor to the required shape.""" if F.shape(x) == broadcast_shape: @@ -42,15 +102,21 @@ def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, x): return _broadcast(final_shape, F.reshape(x, new_shape)) -def _transform_ellipsis_to_slice(tuple_index, data, op_name): - """transform ellipsis in the slice to several slice""" +def _transform_ellipsis_to_slice(data, tuple_index, op_name): + """Check if the tuple index len is longer than the data's dims and transform ellipsis in the indices + to several slice""" data_shape = F.shape(data) data_rank = len(data_shape) indexes_types = hyper_map(F.typeof, tuple_index) slice_positions, ellipsis_positions, _, int_positions, _, tensor_positions, sequence_positions = \ const_utils.get_pos_of_indexes_types(indexes_types, op_name) + ellipsis_occupy_dims = data_rank - (len(slice_positions) + len(int_positions) + len(tensor_positions) + len(sequence_positions)) + ellipsis_cnt = len(ellipsis_positions) + if (ellipsis_cnt == 0 and ellipsis_occupy_dims < 0) or (ellipsis_cnt > 0 and ellipsis_occupy_dims < 1): + const_utils.raise_index_error("For the 'getitem Operator', the data_shape should be no less than the " + "tuple index dims") tuple_index_new = () for i, index in enumerate(tuple_index): @@ -63,122 +129,172 @@ def _transform_ellipsis_to_slice(tuple_index, data, op_name): return tuple_index_new +def _expand_data_dims_with_none(data, tuple_index, op_name): + """expand the data's dim with 'None' in tuple_index""" + indexes_types = hyper_map(F.typeof, tuple_index) + none_positions, tuple_index_without_none = (), () + for i, (index, index_type) in enumerate(zip(tuple_index, indexes_types)): + none_type_tag = const_utils.judge_index_type(index_type, mstype.type_none) + tuple_index_without_none += (const_utils.make_empty_slice(),) if none_type_tag else(index,) + none_positions += (i,) if none_type_tag else () + + for dim in none_positions: + data = F.expand_dims(data, dim) + + return data, tuple_index_without_none + + +def tensor_index_by_slice(data, slice_index): + """Tensor getitem by a single slice""" + shape = F.shape(data) + if not shape: + const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor" + "cannot be 0.") + begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index) + return F.strided_slice(data, begin_strides, end_strides, step_strides) + + +def tensor_index_by_number(data, number): + """Tensor getitem by a Number which may be integer/float/bool value""" + number_type = const_utils.check_number_index_type(number) + if number_type == const_utils.BOOL_: + return _tensor_index_by_bool(data, number) + if number_type == const_utils.INT_: + return _tensor_index_by_integer(data, number) + return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") + + +def _tensor_index_by_bool(data, bool_value): + """Tensor getitem by a single bool value""" + if bool_value: + return F.expand_dims(data, 0) + return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.") + + +def _tensor_index_by_integer(data, number): + """Tensor getitem by a single integer number""" + data_shape = F.shape(data) + data_rank = len(data_shape) + if data_rank == 0: + return const_utils.raise_type_error("When tensor is indexed by an integer, the dimension of the tensor " + "cannot be 0.") + transformed_number = const_utils.check_and_transform_int_index(number, data_shape[0], const_utils.TENSOR_GETITEM) + begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(data_shape, transformed_number) + shrink_axis_mask = 1 + return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) + + +def tensor_index_by_tensor(data, tensor_index): + """Tensor getitem by a single tensor""" + index_type = F.dtype(tensor_index) + const_utils.check_index_type_valid(index_type, mstype.int_type, const_utils.TENSOR_GETITEM) + tensor_index = F.cast(tensor_index, mstype.int64) + return F.gather(data, tensor_index, 0) + + +def tensor_index_by_list(data, list_index): + """Tensor getitem by list of int and bool""" + data_shape = F.shape(data) + const_utils.check_sequence_index_type(list_index, const_utils.TENSOR_GETITEM) + sub_tuple_index = const_utils.transform_sequence_index(list_index, data_shape[0], const_utils.TENSOR_GETITEM) + tensor_index = F.tuple_to_array(sub_tuple_index) + tensor_index = F.cast(tensor_index, mstype.int64) + return F.gather(data, tensor_index, 0) + + +def tensor_index_by_tuple(data, tuple_index): + """Tensor getitem by tuple of various types with None""" + op_name = const_utils.TENSOR_GETITEM + if len(tuple_index) == 1: + return data[tuple_index[0]] + tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) + indexes_types = hyper_map(F.typeof, tuple_index) + contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name) + if contain_type == const_utils.ALL_TENSOR: + return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) + if contain_type == const_utils.ALL_BASIC: + return _tensor_getitem_by_tuple_slice(data, tuple_index) + return _tensor_getitem_by_tuple(data, tuple_index) + + +def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): + """Tensor getitem by a tuple of tensor.""" + indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_GETITEM) + result = F.gather_nd(data, indices) + return result + + +def _tensor_getitem_by_tuple_slice(data, tuple_index): + """Tensor getitem by a tuple of slice""" + data_shape = F.shape(data) + begin_strides, end_strides, step_strides, shrink_axis_mask = \ + const_utils.get_stride_info_from_tuple(data_shape, tuple_index) + return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) + + +def _tensor_getitem_by_tuple(data, tuple_index): + """Tensor getitem by a tuple of mixed tensor.""" + indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM) + result = F.gather_nd(data, indices) + return result + + def _generate_indices_from_tuple_of_tensor(data, tuple_index, op_name): """Generate an indices tensor from a tuple of tensor.""" indices = None - check_index_tensor_number = const_utils.check_number_of_index_tensor(F.shape(data), len(tuple_index), op_name) - if check_index_tensor_number: - dtype_tuple = hyper_map(F.dtype, tuple_index) - check_dtypes = const_utils.check_index_tensors_dtype(dtype_tuple, op_name) - if check_dtypes: - shape_tuple = hyper_map(F.shape, tuple_index) - broadcast_shape = const_utils.generate_broadcast_shape(shape_tuple, op_name) - broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) - indices = pack(broadcast_tensors) + indexes_types = hyper_map(F.dtype, tuple_index) + const_utils.check_indexes_types_valid(indexes_types, mstype.int_type, op_name) + tensor_index_shape = hyper_map(F.shape, tuple_index) + broadcast_shape = const_utils.generate_broadcast_shape(tensor_index_shape, op_name) + broadcast_tensors = hyper_map(F.partial(_broadcast, broadcast_shape), tuple_index) + indices = pack(broadcast_tensors) + indices = F.cast(indices, mstype.int64) return indices def _generate_indices_from_tuple(data, tuple_index, op_name): """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" data_shape = F.shape(data) + tuple_index_len = len(tuple_index) + tensor_indexes, slice_indexes = [], [] indexes_types = hyper_map(F.typeof, tuple_index) - int_positions, sequence_positions = const_utils.get_pos_of_int_sequence(indexes_types) + slice_positions, _, _, int_positions, _, \ + tensor_positions, sequence_positions = const_utils.get_pos_of_indexes_types(indexes_types, op_name) tuple_index_new = () - tuple_len = len(tuple_index) - for i in range(tuple_len): - index = tuple_index[i] - shape = data_shape[i] + for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)): if i in int_positions: - int_index = const_utils.check_and_transform_int_index(index, shape, op_name) - tensor_index = F.scalar_to_tensor(int_index, mstype.int32) + int_index = const_utils.check_and_transform_int_index(index, dim_size, op_name) + tensor_index = F.scalar_to_tensor(int_index, mstype.int64) tuple_index_new += (tensor_index,) + tensor_indexes.append(tensor_index) + tensor_positions.append(i) elif i in sequence_positions: - sequence_index = const_utils.transform_sequence_index(index, shape, op_name) + sequence_index = const_utils.transform_sequence_index(index, dim_size, op_name) tensor_index = F.tuple_to_array(sequence_index) + tensor_index = F.cast(tensor_index, mstype.int64) tuple_index_new += (tensor_index,) - else: + tensor_indexes.append(tensor_index) + tensor_positions.append(i) + elif i in tensor_positions: + tensor_index = F.cast(index, mstype.int64) + tuple_index_new += (tensor_index,) + tensor_indexes.append(tensor_index) + elif i in slice_positions: + slice_indexes.append(index) tuple_index_new += (index,) - indexes_types = hyper_map(F.typeof, tuple_index_new) - tensor_positions, slice_positions, ellipsis_position = \ - const_utils.separate_mixed_tensors_index(indexes_types, op_name) - tensor_indexes, slice_indexes = [], [] - for i in tensor_positions: - tensor_indexes.append(tuple_index_new[i]) - for j in slice_positions: - slice_indexes.append(tuple_index_new[j]) - tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) - broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ - const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape, - indexes_types, - tensor_indexes_shapes, - tensor_indexes_dtypes, - slice_indexes, - op_name) - - slice_number = 0 - final_index_tensors = [] - tuple_index_size = len(tuple_index_new) - index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) - for i in range(tuple_index_size): - if i in tensor_positions: - transform_tensor = _transform_indexing_tensor( - broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]) - final_index_tensors.append(transform_tensor) - if i in slice_positions: - slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) - final_index_tensors.append(slice_tensor) - slice_number += 1 - if i == ellipsis_position: - ellipsis_tensors = const_utils.convert_ellipsis_to_tensors( - slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name) - for ele in ellipsis_tensors: - final_index_tensors.append(ele) - slice_number += ellipsis_occupied_dims - indices = pack(final_index_tensors) - return indices - - -def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): - """Generate an indices tensor from a tuple that contains slice, int, ellipsis, tensor.""" - data_shape = F.shape(data) - indexes_types = hyper_map(F.typeof, tuple_index) - int_positions = const_utils.get_pos_of_int_index(indexes_types) - tuple_index_new = () - tuple_len = len(tuple_index) - for i in range(tuple_len): - if i in int_positions: - tuple_index_new += (F.scalar_to_tensor(tuple_index[i] if tuple_index[i] >= 0 else tuple_index[i] + - data_shape[i], mstype.int32),) - else: - tuple_index_new += (tuple_index[i],) indexes_types = hyper_map(F.typeof, tuple_index_new) - tensor_positions, slice_positions, ellipsis_position = \ - const_utils.separate_mixed_tensors_index(indexes_types, op_name) - tensor_indexes = [] - slice_indexes = [] - for i in tensor_positions: - tensor_indexes.append(tuple_index_new[i]) - for j in slice_positions: - slice_indexes.append(tuple_index_new[j]) - tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes) - tensor_indexes_dtypes = hyper_map(F.dtype, tensor_indexes) - broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims = \ - const_utils.generate_index_info_from_tuple_of_mixed_tensors(data_shape, - indexes_types, - tensor_indexes_shapes, - tensor_indexes_dtypes, - slice_indexes, - op_name) + broadcast_shape, final_shape, indexes_shapes_info = const_utils.generate_index_info_from_tuple_of_mixed_tensors( + data_shape, indexes_types, tensor_indexes_shapes, tensor_indexes_dtypes, slice_indexes, op_name) slice_number = 0 final_index_tensors = [] - tuple_index_size = len(tuple_index_new) index_tensor_new_shape = const_utils.compute_new_shape(broadcast_shape, indexes_shapes_info) - for i in range(tuple_index_size): + for i in range(tuple_index_len): if i in tensor_positions: transform_tensor = _transform_indexing_tensor( broadcast_shape, final_shape, index_tensor_new_shape, tuple_index_new[i]) @@ -187,12 +303,7 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name): slice_tensor = const_utils.convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name) final_index_tensors.append(slice_tensor) slice_number += 1 - if i == ellipsis_position: - ellipsis_tensors = const_utils.convert_ellipsis_to_tensors( - slice_number, ellipsis_occupied_dims, final_shape, indexes_shapes_info, op_name) - for ele in ellipsis_tensors: - final_index_tensors.append(ele) - slice_number += ellipsis_occupied_dims + indices = pack(final_index_tensors) return indices @@ -239,179 +350,8 @@ def _generate_updates_from_tensor(data, index, value, op_type): return value -def _tensor_getitem(self, index): - """Handle tensor getitem""" - if isinstance(index, Tensor): - return tensor_index_by_tensor(self, index) - if isinstance(index, tuple): - return tensor_index_by_tuple(self, index) - if isinstance(index, list): - return tensor_index_by_list(self, index) - # bool type should be judged before int - if isinstance(index, bool): - return _tensor_index_by_bool(self, index) - if isinstance(index, int): - return _tensor_index_by_integer(self, index) - if isinstance(index, slice): - return tensor_index_by_slice(self, index) - if index is None: - return F.expand_dims(self, 0) - if index is ...: - return self - raise IndexError(f"Only support integers, slices(`:`), ellipsis(`...`), None, bool and tensor with int32, " - f"got {index} with type {type(index)}.") - - tensor_operator_registry.register("__getitem__", _tensor_getitem) - -def _tensor_getitem_by_tuple_of_tensor(data, tuple_index): - """Tensor getitem by a tuple of tensor.""" - indices = _generate_indices_from_tuple_of_tensor(data, - tuple_index, - const_utils.TENSOR_GETITEM) - result = F.gather_nd(data, indices) - return result - - -def _tensor_getitem_by_tuple(data, tuple_index): - """Tensor getitem by a tuple of mixed tensor.""" - indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_GETITEM) - result = F.gather_nd(data, indices) - return result - - -def _tensor_getitem_by_tuple_of_mixed_tensors(data, tuple_index): - """Tensor getitem by a tuple of mixed tensor.""" - indices = _generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, - const_utils.TENSOR_GETITEM) - result = F.gather_nd(data, indices) - return result - - -def tensor_index_by_slice(data, slice_index): - """Tensor getitem by a single slice""" - shape = F.shape(data) - if not shape: - const_utils.raise_index_error("When tensor is indexed by a slice, the dimension of the tensor" - "cannot be 0.") - begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_slice(shape, slice_index) - return F.strided_slice(data, begin_strides, end_strides, step_strides) - - -def _tensor_index_by_integer(data, number): - """Tensor getitem by a single integer number""" - shape = F.shape(data) - if not shape: - return const_utils.raise_type_error("When tensor is indexed by an integer," - "the dimension of the tensor cannot be 0.") - if number >= shape[0]: - return const_utils.raise_index_error("index {} is out of bounds for axis 0 with size {}".format( - number, shape[0])) - begin_strides, end_strides, step_strides = const_utils.get_stride_info_from_integer(shape, number) - shrink_axis_mask = 1 - return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) - - -def _tensor_index_by_bool(data, bool_value): - """Tensor getitem by a single bool value""" - if bool_value: - return F.expand_dims(data, 0) - return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.") - - -def tensor_index_by_number(data, number): - """Tensor getitem by a Number which may be integer/float/bool value""" - number_type = const_utils.check_number_index_type(number) - if number_type == const_utils.BOOL_: - return _tensor_index_by_bool(data, number) - if number_type == const_utils.INT_: - return _tensor_index_by_integer(data, number) - return const_utils.raise_index_error("Only support integers, slices(`:`), ellipsis(`...`), None and bool.") - - -def tensor_index_by_tensor(data, tensor_index): - """Tensor getitem by a single tensor""" - dtype_valid = const_utils.check_index_tensor_dtype(F.dtype(tensor_index), - const_utils.TENSOR_GETITEM) - if dtype_valid: - return F.gather(data, tensor_index, 0) - return const_utils.raise_index_error("For 'tensor getitem', " - "the index tensor data type only support mstype.int32.") - - -def _tensor_index_by_tuple_slice(data, tuple_index): - """Tensor getitem by a tuple of slice""" - data_shape = F.shape(data) - if len(tuple_index) > len(data_shape): - const_utils.raise_index_error("When tensor is indexed by a tuple, the length of the tuple cannot " - "be greater than the dimension of the tensor.") - begin_strides, end_strides, step_strides, shrink_axis_mask = \ - const_utils.get_stride_info_from_tuple(data_shape, tuple_index) - return P.StridedSlice(0, 0, 0, 0, shrink_axis_mask)(data, begin_strides, end_strides, step_strides) - - -def tensor_index_by_list(data, list_index): - """Tensor getitem by list of int and bool""" - data_shape = F.shape(data) - const_utils.check_list_index_type(list_index) - list_index = const_utils.transform_list(list_index, data_shape[0]) - tensor_index = const_utils.convert_list_to_tensor(list_index) - return F.gather(data, tensor_index, 0) - - -def tensor_index_by_tuple(data, tuple_index): - """Tensor getitem by tuple of various types with None""" - if len(tuple_index) == 1: - return data[tuple_index[0]] - tuple_index = _transform_ellipsis_to_slice(tuple_index, data, const_utils.TENSOR_GETITEM) - indexes_types = hyper_map(F.typeof, tuple_index) - contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_GETITEM) - if contain_type == const_utils.ALL_TENSOR: - return _tensor_getitem_by_tuple_of_tensor(data, tuple_index) - if contain_type == const_utils.ALL_BASIC: - return _tensor_index_by_tuple_slice(data, tuple_index) - return _tensor_getitem_by_tuple(data, tuple_index) - - -def _tensor_setitem(self, index, value): - """Handle tensor getitem""" - if isinstance(index, Tensor): - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_tensor_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_tensor_with_tensor(self, index, value) - if isinstance(value, tuple): - return tensor_setitem_by_tensor_with_tuple(self, index, value) - if isinstance(index, tuple): - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_tuple_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_tuple_with_tensor(self, index, value) - if isinstance(value, tuple): - return tensor_setitem_by_tuple_with_tuple(self, index, value) - if isinstance(index, int): - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_number_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_number_with_tensor(self, index, value) - if isinstance(index, slice): - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_slice_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_slice_with_tensor(self, index, value) - if isinstance(index, bool): - return _tensor_index_by_bool(self, index) - if index is ...: - if isinstance(value, (int, float, bool)): - return tensor_setitem_by_ellipsis_with_number(self, index, value) - if isinstance(value, Tensor): - return tensor_setitem_by_ellipsis_with_tensor(self, index, value) - raise IndexError("Tensor setitem index only support integers, slices(`:`), ellipsis(`...`), None, bool\ - and tensor with int32, got {} with type{}".format(index, type(index))) - - tensor_operator_registry.register("__setitem__", _tensor_setitem) @@ -532,24 +472,21 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): if len(tuple_index) == 1: data[tuple_index[0]] = value return data + op_name = const_utils.TENSOR_GETITEM + tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) + indexes_types = hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) + contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) - if index_elements_type == const_utils.ALL_TENSOR: - indices = _generate_indices_from_tuple_of_tensor(data, - tuple_index, - const_utils.TENSOR_SETITEM) + if contain_type == const_utils.ALL_TENSOR: + indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) else: int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) if int_cnt == const_utils.ALL_INT: tuple_index = const_utils.convert_int_to_slice(tuple_index) - indices = _generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, - const_utils.TENSOR_SETITEM) - updates = _generate_updates_from_scalar(data, - indices, - value, - const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) + updates = _generate_updates_from_scalar(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) return P.TensorScatterUpdate()(data, indices, updates) @@ -597,42 +534,26 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): if len(tuple_index) == 1: data[tuple_index[0]] = value return data - data_shape = data.shape - tuple_index_new = () - for i, index in enumerate(tuple_index): - if isinstance(index, mstype.Int): - if index < -data_shape[i] or index >= data_shape[i]: - const_utils.raise_index_error("The index is out of the data's special dimension range.") - elif index < 0: - tuple_index_new += (tuple_index[i]+data_shape[i],) - else: - tuple_index_new += (tuple_index[i],) - else: - tuple_index_new += (tuple_index[i],) + op_name = const_utils.TENSOR_GETITEM + tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) - indexes_types = hyper_map(F.typeof, tuple_index_new) - index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) + indexes_types = hyper_map(F.typeof, tuple_index) + contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) - if index_elements_type == const_utils.ALL_TENSOR: - indices = _generate_indices_from_tuple_of_tensor(data, - tuple_index_new, - const_utils.TENSOR_SETITEM) + if contain_type == const_utils.ALL_TENSOR: + indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) else: int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) if int_cnt == const_utils.ALL_INT: - tuple_index_new = const_utils.convert_int_to_slice(tuple_index_new) + tuple_index = const_utils.convert_int_to_slice(tuple_index) new_shape = () - for _ in tuple_index_new: + for _ in tuple_index: new_shape += (1,) new_shape += value.shape value = F.reshape(value, new_shape) - indices = _generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index_new, - const_utils.TENSOR_SETITEM) - updates = _generate_updates_from_tensor(data, - indices, - value, - const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) + updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) return P.TensorScatterUpdate()(data, indices, updates) @@ -641,24 +562,21 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): if len(tuple_index) == 1: data[tuple_index[0]] = value return data + op_name = const_utils.TENSOR_GETITEM + tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name) + data, tuple_index = _expand_data_dims_with_none(data, tuple_index, op_name) + indexes_types = hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) + contain_type = const_utils.tuple_index_type_cnt(indexes_types, const_utils.TENSOR_SETITEM) - if index_elements_type == const_utils.ALL_TENSOR: - indices = _generate_indices_from_tuple_of_tensor(data, - tuple_index, - const_utils.TENSOR_SETITEM) + if contain_type == const_utils.ALL_TENSOR: + indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) else: int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) if int_cnt == const_utils.ALL_INT: tuple_index = const_utils.convert_int_to_slice(tuple_index) - indices = _generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, - const_utils.TENSOR_SETITEM) - updates = _generate_updates_from_tuple(data, - indices, - value, - const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) + indices = _generate_indices_from_tuple(data, tuple_index, const_utils.TENSOR_SETITEM) + updates = _generate_updates_from_tuple(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR) return P.TensorScatterUpdate()(data, indices, updates) diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 3972f94f7d..57a1d4cd78 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -69,16 +69,8 @@ def check_equal(param1, param2, msg="{},{}"): @constexpr -def split_tuple_index_for_none(tuple_index): - """return the none_positions and the tuple_index_without_none whose None index is replaced by slice.""" - none_positions, tuple_index_without_none = (), () - for idx, item in enumerate(tuple_index): - if item is None: - none_positions += (idx,) - tuple_index_without_none += (slice(None, None, None),) - else: - tuple_index_without_none += (item,) - return none_positions, tuple_index_without_none +def make_empty_slice(): + return slice(None, None, None) @constexpr @@ -139,10 +131,31 @@ def check_valid_dim(dim, name): @constexpr -def check_valid_type(data_type, value_type, name): - if not data_type in value_type: - raise TypeError( - f"For {name}, valid type include {value_type}, {data_type} is invalid") +def judge_index_type(index_type, target_type): + if index_type == target_type or (isinstance(target_type, (list, tuple)) and index_type in target_type): + return True + return False + + +@constexpr +def check_type_valid(dtype, target_type, op_name): + if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): + raise TypeError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") + + +@constexpr +def check_index_type_valid(dtype, target_type, op_name): + if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): + raise IndexError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") + + +@constexpr +def check_indexes_types_valid(dtypes, target_type, op_name): + """Check a tuple of tensor data type.""" + for dtype in dtypes: + if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): + raise IndexError(f"For '{op_name}', the all index tensor data types should be in {target_type}, " + f"but got {dtype}.") def slice_expand(input_slices, shape): @@ -156,9 +169,7 @@ def slice_expand(input_slices, shape): Outputs: tuple[list], This is expressed as (begins, ends, strides). """ - begin = [] - end = [] - strides = [] + begin, end, strides = [], [], [] index = 0 slices = None # Slice or tuple(Slice...) @@ -269,19 +280,6 @@ def integer_to_indices(index, shape): return Tensor(value, dtype=mstype.int32) -@constexpr -def tuple_element_is_slice(indexs): - """Judges tuple element type.""" - if not indexs: - raise IndexError("Tensor's index cannot be empty.") - if isinstance(indexs, tuple): - for _, ele in enumerate(indexs): - if not isinstance(ele, slice): - return False - return True - return False - - @constexpr def tuple_element_is_int(indexs): """Judges tuple element type.""" @@ -395,8 +393,7 @@ def generate_broadcast_shape(shapes, op_name): for i, shape in enumerate(shapes): logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") try: - broadcast_shape = op_utils.get_broadcast_shape( - broadcast_shape, shape, op_name) + broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) except ValueError as ex: raise IndexError(ex) return tuple(broadcast_shape) @@ -439,80 +436,17 @@ def compute_new_shape(origin_shape, indexes_shapes_info): @constexpr -def check_list_index_type(list_index): +def check_sequence_index_type(sequence_index, op_name): """check if the item's type of list_index is bool or int""" - if not all([isinstance(index, (int, bool)) for index in list_index]): - raise IndexError( - f"Tensor only support 'integer' or 'boolean' array(list/tuple), but got {type(index)} in array") + if not all([isinstance(index, (int, bool)) for index in sequence_index]): + raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), " + f"but got {type(index)} in array") @constexpr -def transform_list(list_index, shape): - """transfor list_index from int or bool to int""" - bool_count = len(list(filter(lambda index: isinstance(index, bool), list_index))) - int_count = len(list(filter(lambda index: isinstance(index, int), list_index)))-bool_count - if int_count == 0: - if bool_count == shape: - list_index = list(filter(lambda i: list_index[i], range(bool_count))) - else: - raise IndexError("The boolean array should have the same length with the corresponding dimensiton") - else: - list_index = [int(index) for index in list_index] - for i, index in enumerate(list_index): - if index < -shape or index >= shape: - raise IndexError(f"The index should in the range [-{shape}, {shape-1}] to fit the corresponding dim " - f"length, but get {index}.") - if index < 0: - index += shape - list_index[i] = index - return list_index - - -@constexpr -def convert_list_to_tensor(list_index): - """convert the list_index to tensor_index with mstype.int64 dtype""" - return Tensor(list_index, mstype.int64) - - -@constexpr -def convert_int_to_slice(tuple_indexes): - tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes) - return tuple_indexes_new - - -@constexpr -def convert_ellipsis_to_tensors(slice_number, - ellipsis_occupied_dims, - final_shape, - indexes_shapes_info, - op_name): - """Convert an ellipsis to a list of tensor.""" - tensor_list = [] - dims_dealt_count = 0 - while dims_dealt_count < ellipsis_occupied_dims: - shape = [] - slice_count = 0 - array = None - for ele in indexes_shapes_info: - if isinstance(ele, list): - if slice_count == slice_number: - array = np.array(ele, np.int32) - shape.append(len(ele)) - else: - shape.append(1) - slice_count += 1 - if isinstance(ele, tuple): - shape.extend([1] * len(ele)) - if array is None: - raise ValueError( - f"For '{op_name}', generate tensors from ellipsis failed.") - array = np.reshape(array, shape) - reps = compute_multiples(shape, final_shape) - tensor = Tensor(np.tile(array, reps)) - tensor_list.append(tensor) - slice_number += 1 - dims_dealt_count += 1 - return tensor_list +def convert_int_to_slice(tuple_index): + tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index) + return tuple_index_new @constexpr @@ -567,7 +501,7 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n f"For '{op_name}', generate tensor from 'slice' failed.") array = np.reshape(array, shape) reps = compute_multiples(shape, final_shape) - tensor = Tensor(np.tile(array, reps)) + tensor = Tensor(np.tile(array, reps), mstype.int64) return tensor @@ -617,21 +551,17 @@ def generate_updates_shape(data_shape, index_shape, op_type): @constexpr -def check_number_of_index_tensor(data_shape, tuple_len, op_name): +def check_tuple_index_len(data_rank, tuple_index_len, op_name): """Check if the number of index tensor exceeds the dimension of the operated tensor.""" - if tuple_len <= len(data_shape): + if tuple_index_len <= data_rank: return True - raise IndexError(f"For '{op_name}', the number {tuple_len} of index tensor " - f"is greater than the dimension {len(data_shape)} of the operated tensor.") + raise IndexError(f"For '{op_name}', the number {tuple_index_len} of tuple_index size" + f"is greater than the dimension {data_rank} of the operated tensor.") @constexpr -def generate_index_info_from_tuple_of_mixed_tensors(data_shape, - indexes_types, - tensor_indexes_shapes, - tensor_indexes_dtypes, - slice_indexes, - op_name): +def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes, + tensor_indexes_dtypes, slice_indexes, op_name): """ Generate index info which contain broadcast shape, final shape, indexes shapes info, ellipsis size from a tuple of mixed tensors. @@ -642,22 +572,14 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, if indexes_size > data_rank: raise IndexError(f"For '{op_name}', the number {indexes_size} of index elements " f"is greater than the dimension {len(data_shape)} of the operated tensor.") - indexes_info = {} - index_tensors_info = {} - ellipsis_num = 0 - ellipsis_occupied_dims = 0 - tensor_count = 0 - slice_count = 0 - for i, ele_type in enumerate(indexes_types): - if ellipsis_num == 0: - pos = i - else: - pos = i + ellipsis_occupied_dims - 1 - if isinstance(ele_type, mstype.tensor_type): + indexes_info, index_tensors_info = {}, {} + tensor_count, slice_count = 0, 0 + for pos, index_type in enumerate(indexes_types): + if isinstance(index_type, mstype.tensor_type): indexes_info[pos] = tensor_indexes_shapes[tensor_count] index_tensors_info[pos] = tensor_indexes_shapes[tensor_count] tensor_count += 1 - elif isinstance(ele_type, mstype.slice_type): + elif isinstance(index_type, mstype.slice_type): slice_obj = slice(slice_indexes[slice_count].start, slice_indexes[slice_count].stop, slice_indexes[slice_count].step) @@ -669,22 +591,12 @@ def generate_index_info_from_tuple_of_mixed_tensors(data_shape, slice_indexes[slice_count].stop, slice_indexes[slice_count].step)) slice_count += 1 - elif isinstance(ele_type, mstype.ellipsis_type): - if ellipsis_num != 0: - raise IndexError( - f"For '{op_name}', the index could only contain one ellipsis.") - ellipsis_occupied_dims = data_rank - indexes_size + 1 - for j in range(pos, pos + ellipsis_occupied_dims): - # Use list to represent slicing result. - indexes_info[j] = list(range(data_shape[j])) - ellipsis_num += 1 else: raise IndexError(f"For '{op_name}', the index elements only support " - f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {ele_type}.") - broadcast_shape, final_shape, indexes_shapes_info = \ - _derive_result_shape_info_from_tuple_of_mixed_tensors( - indexes_info, index_tensors_info, op_name) - return broadcast_shape, final_shape, indexes_shapes_info, ellipsis_occupied_dims + f"'Tensor', 'int', 'Slice', 'Ellipsis', but got {index_type}.") + broadcast_shape, final_shape, indexes_shapes_info = _derive_result_shape_info_from_tuple_of_mixed_tensors( + indexes_info, index_tensors_info, op_name) + return broadcast_shape, final_shape, indexes_shapes_info def _judge_tuple_of_mixed_tensors_continuous(index_tensor_info_key: list): @@ -701,8 +613,7 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te index_tensor_info_value = list(index_tensors_info.values()) broadcast_shape = generate_broadcast_shape( index_tensor_info_value, op_name) - final_shape = [] - indexes_shapes_info = [] + final_shape, indexes_shapes_info = [], [] mixed_tensors_continuous = _judge_tuple_of_mixed_tensors_continuous( index_tensor_info_key) if mixed_tensors_continuous: @@ -734,54 +645,6 @@ def _derive_result_shape_info_from_tuple_of_mixed_tensors(indexes_info, index_te return broadcast_shape, tuple(final_shape), tuple(indexes_shapes_info) -@constexpr -def make_empty_slice(): - empty_slice = slice(None, None, None) - return empty_slice - - -@constexpr -def get_pos_of_int_index(indexes_types): - """Get int index positions from the mixed tensors index which contains int, tensor, slice, and ellipsis.""" - int_positions = [] - for i, ele_type in enumerate(indexes_types): - if ele_type in (mstype.int32, mstype.int64): - int_positions.append(i) - return int_positions - - -@constexpr -def get_pos_of_int_sequence(indexes_types): - """Get int and sequence index positions from the mixed tensors index.""" - int_positions, sequence_positions = [], [] - for i, index_type in enumerate(indexes_types): - if isinstance(index_type, mstype.Int): - int_positions.append(i) - elif isinstance(index_type, (tuple, list)): - sequence_positions.append(i) - return int_positions, sequence_positions - - -@constexpr -def separate_mixed_tensors_index(indexes_types, op_name): - """Separate the position information of tensor and slice and ellipsis from the mixed tensors index.""" - tensor_positions = [] - slice_positions = [] - ellipsis_position = None - for i, ele_type in enumerate(indexes_types): - if isinstance(ele_type, mstype.tensor_type): - tensor_positions.append(i) - elif isinstance(ele_type, mstype.slice_type): - slice_positions.append(i) - elif isinstance(ele_type, mstype.ellipsis_type): - ellipsis_position = i - else: - raise IndexError(f"For '{op_name}', the index elements only support " - f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {ele_type}.") - - return tensor_positions, slice_positions, ellipsis_position - - @constexpr def get_pos_of_indexes_types(indexes_types, op_name): """Separate the position information of tensor and slice and ellipsis from the mixed tensors index.""" @@ -805,6 +668,8 @@ def get_pos_of_indexes_types(indexes_types, op_name): else: raise IndexError(f"For '{op_name}', the index elements only support " f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.") + if len(ellipsis_positions) > 1: + raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')") return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \ tensor_positions, sequence_positions @@ -906,10 +771,10 @@ def get_stride_info_from_tuple(data_shape, tuple_index): ellipsis_count = ellipsis_count + 1 if ellipsis_count > 1: raise IndexError("An index can have only one ellipsis (...)") - ellipsis_range_size = data_rank - (tuple_index_len - 1) + ellipsis_range_size = data_rank - tuple_index_len + 1 begin_strides.extend([0] * (ellipsis_range_size)) end_strides.extend( - [i for i in data_shape[index_count: index_count + (ellipsis_range_size)]]) + [shape for shape in data_shape[index_count: index_count + ellipsis_range_size]]) step_strides.extend([1] * (ellipsis_range_size)) index_count = index_count + ellipsis_range_size else: diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 8ec376162a..91cfbde791 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -162,7 +162,7 @@ def _tensor_getitem_by_number(data, number_index): @getitem.register("Tensor", "None") -def _tensor_getitem_by_none(data, index): +def _tensor_getitem_by_none(data, none_index): """ For none indexing , expand data with one dim. diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py index b18bdcf4e8..6861f27b0d 100644 --- a/mindspore/ops/composite/multitype_ops/setitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -132,6 +132,7 @@ def _dict_setitem_with_number(data, key, value): """ return F.dict_setitem(data, key, value) + @setitem.register("Dictionary", "String", "Tuple") def _dict_setitem_with_tuple(data, key, value): """ @@ -147,6 +148,7 @@ def _dict_setitem_with_tuple(data, key, value): """ return F.dict_setitem(data, key, value) + @setitem.register("Tensor", "Tensor", "Tensor") def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor): """ diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index bea9ee456d..4d050dfb43 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -21,6 +21,7 @@ from .multitype_ops import _constexpr_utils as const_utils from ...common import dtype as mstype from ...common.seed import _get_graph_seed + @constexpr def _get_seed(op_seed, kernel_name): "Get the graph-level seed." @@ -59,14 +60,15 @@ def normal(shape, mean, stddev, seed=None): """ mean_dtype = F.dtype(mean) stddev_dtype = F.dtype(stddev) - const_utils.check_valid_type(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal') - const_utils.check_valid_type(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal') + const_utils.check_type_valid(mean_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal') + const_utils.check_type_valid(stddev_dtype, mstype.int_type + (mstype.float16, mstype.float32), 'normal') seed1, seed2 = _get_seed(seed, "normal") stdnormal = P.StandardNormal(seed1, seed2) random_normal = stdnormal(shape) value = random_normal * stddev + mean return value + def laplace(shape, mean, lambda_param, seed=None): r""" Generates random numbers according to the Laplace random number distribution. @@ -112,6 +114,7 @@ def laplace(shape, mean, lambda_param, seed=None): value = rnd * lambda_param + mean return value + def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32): """ Generates random numbers according to the Uniform random number distribution. @@ -159,7 +162,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32): """ minval_dtype = F.dtype(minval) maxval_dtype = F.dtype(maxval) - const_utils.check_valid_type(dtype, [mstype.int32, mstype.float32], 'uniform') + const_utils.check_type_valid(dtype, [mstype.int32, mstype.float32], 'uniform') const_utils.check_tensors_dtype_same(minval_dtype, dtype, "uniform") const_utils.check_tensors_dtype_same(maxval_dtype, dtype, "uniform") seed1, seed2 = _get_seed(seed, "uniform") @@ -172,6 +175,7 @@ def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32): value = random_uniform * (maxval - minval) + minval return value + def gamma(shape, alpha, beta, seed=None): """ Generates random numbers according to the Gamma random number distribution. @@ -205,6 +209,7 @@ def gamma(shape, alpha, beta, seed=None): value = random_gamma(shape, alpha, beta) return value + def poisson(shape, mean, seed=None): """ Generates random numbers according to the Poisson random number distribution. @@ -235,6 +240,7 @@ def poisson(shape, mean, seed=None): value = random_poisson(shape, mean) return value + def multinomial(inputs, num_sample, replacement=True, seed=None): r""" Returns a tensor sampled from the multinomial probability distribution located in the corresponding diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index be56f49468..3332050579 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -724,6 +724,7 @@ class Transpose(PrimitiveWithInfer): out['max_shape'] = tuple(max_vec) return out + class Unique(Primitive): """ Returns the unique elements of input tensor and also return a tensor containing the index of each value of input @@ -2787,7 +2788,7 @@ class StridedSlice(PrimitiveWithInfer): if has_ellipsis: # When there is ellipsis, handle the second half of the ellipsis split. ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \ - len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len]))) + len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len]))) ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims]) j += 1 i += ellipsis_occupied_dims @@ -3144,7 +3145,7 @@ class TensorScatterUpdate(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype, indices_dtype, value_dtype): - validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) + validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32, mstype.int64], self.name) args = {"x": x_dtype, "value": value_dtype} validator.check_tensors_dtypes_same_and_valid(args, (mstype.bool_,) + mstype.number_type, self.name) return x_dtype @@ -3983,7 +3984,7 @@ class SpaceToBatchND(PrimitiveWithInfer): offset = 1 for i in range(len(self.block_shape)): padded = out_shape[i + offset] + self.paddings[i][0] + \ - self.paddings[i][1] + self.paddings[i][1] if padded % self.block_shape[i] != 0: raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by ' f'block_shape[{i}] {self.block_shape[i]}')