diff --git a/mindspore/common/dtype.py b/mindspore/common/dtype.py index 3289e87d2b..dc047648e1 100644 --- a/mindspore/common/dtype.py +++ b/mindspore/common/dtype.py @@ -38,7 +38,7 @@ __dtype__ = [ "number", "tensor", "string", "type_none", "tensor_type", - "Type" + "Type", "Int" ] __method__ = [ @@ -104,6 +104,7 @@ tuple_type = typing.Tuple index_slices = typing.RowTensorType() sparse_tensor = typing.SparseTensorType() undetermined = typing.UndeterminedType() +Int = typing.Int number_type = (int8, int16, diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 5667a5b15f..f592eba7b3 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -271,7 +271,7 @@ def tensor_index_by_tuple(data, tuple_index): if len(tuple_index) == 1: return data[tuple_index[0]] indexes_types = hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_GETITEM) + index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_GETITEM) if index_elements_type == const_utils.NO_TENSOR: return _tensor_index_by_tuple_slice(data, tuple_index) if index_elements_type == const_utils.ALL_TENSOR: @@ -437,13 +437,16 @@ def tensor_setitem_by_tuple_with_number(data, tuple_index, value): data[tuple_index[0]] = value return data indexes_types = hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) if index_elements_type == const_utils.ALL_TENSOR: indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) else: + int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) + if int_cnt == const_utils.ALL_INT: + tuple_index = const_utils.convert_int_to_slice(tuple_index) indices = _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, const_utils.TENSOR_SETITEM) @@ -498,16 +501,37 @@ def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value): if len(tuple_index) == 1: data[tuple_index[0]] = value return data - indexes_types = hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + data_shape = data.shape + tuple_index_new = () + for i, index in enumerate(tuple_index): + if isinstance(index, mstype.Int): + if index < -data_shape[i] or index >= data_shape[i]: + const_utils.raise_index_error("The index is out of the data's special dimension range.") + elif index < 0: + tuple_index_new += (tuple_index[i]+data_shape[i],) + else: + tuple_index_new += (tuple_index[i],) + else: + tuple_index_new += (tuple_index[i],) + + indexes_types = hyper_map(F.typeof, tuple_index_new) + index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) if index_elements_type == const_utils.ALL_TENSOR: indices = _generate_indices_from_tuple_of_tensor(data, - tuple_index, + tuple_index_new, const_utils.TENSOR_SETITEM) else: + int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) + if int_cnt == const_utils.ALL_INT: + tuple_index_new = const_utils.convert_int_to_slice(tuple_index_new) + new_shape = () + for _ in tuple_index_new: + new_shape += (1,) + new_shape += value.shape + value = F.reshape(value, new_shape) indices = _generate_indices_from_tuple_of_mixed_tensors(data, - tuple_index, + tuple_index_new, const_utils.TENSOR_SETITEM) updates = _generate_updates_from_tensor(data, indices, @@ -522,13 +546,16 @@ def tensor_setitem_by_tuple_with_tuple(data, tuple_index, value): data[tuple_index[0]] = value return data indexes_types = hyper_map(F.typeof, tuple_index) - index_elements_type = const_utils.tuple_index_elements_type(indexes_types, const_utils.TENSOR_SETITEM) + index_elements_type = const_utils.tuple_index_tensor_cnt(indexes_types, const_utils.TENSOR_SETITEM) if index_elements_type == const_utils.ALL_TENSOR: indices = _generate_indices_from_tuple_of_tensor(data, tuple_index, const_utils.TENSOR_SETITEM) else: + int_cnt = const_utils.tuple_index_int_cnt(indexes_types, const_utils.TENSOR_SETITEM) + if int_cnt == const_utils.ALL_INT: + tuple_index = const_utils.convert_int_to_slice(tuple_index) indices = _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, const_utils.TENSOR_SETITEM) diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 88ed6db888..757e6559cc 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -28,6 +28,9 @@ ALL_TENSOR = 0 NO_TENSOR = 1 CONTAIN_TENSOR = 2 ALL_SCALAR = 3 +ALL_INT = 4 +NO_INT = 5 +CONTAIN_INT = 6 INT_ = 0 BOOL_ = 1 @@ -72,6 +75,35 @@ def check_ellipsis_shape_size(data_shape, value_shape, data_size, value_size): value_shape, data_shape)) +@constexpr +def restrict_int_index(data_shape, tuple_indexes): + """ + Check the int index of tuple_indexes if value of index is out of the corresponding data shape + and turn the negtive int index to positive int index. + + Inputs: + data_shape: the shape of data. + tuple_indexes(tuple[mstype.int32]): the tuple of index which will be used in setitem or getitem. + + Outputs: + tuple_indexes_new(tuple[mstype.int32]): same purpose with tuple_indexes but only contain positive. + """ + if tuple_indexes is None: + return tuple_indexes + tuple_indexes_new = () + for i, index in enumerate(tuple_indexes): + if isinstance(index, mstype.Int): + if index < -data_shape[i] or index >= data_shape[i]: + const_utils.raise_index_error("The index is out of the data's special dimension range.") + elif index < 0: + tuple_indexes_new += (tuple_indexes[i]+data_shape[i],) + else: + tuple_indexes_new += (tuple_indexes[i],) + else: + tuple_indexes_new += (tuple_indexes[i],) + return tuple_indexes_new + + @constexpr def check_tensor_setitem_index(index, element_type=None): """Checks tuple index type of tensor assignment.""" @@ -276,17 +308,17 @@ def tuple_element_is_int(indexs): @constexpr -def tuple_index_elements_type(types, op_name): - """Judges the type of all elements of the tuple.""" - tensors_number = 0 - for ele in types: - if isinstance(ele, mstype.tensor_type): - tensors_number += 1 - if tensors_number == len(types): - return ALL_TENSOR - if tensors_number == 0: - return NO_TENSOR - return CONTAIN_TENSOR +def tuple_index_tensor_cnt(types, op_name): + """count the tensor type of types which contains the tuple elements' type.""" + tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) + return ALL_TENSOR if tensor_cnt == len(types) else NO_TENSOR if tensor_cnt == 0 else CONTAIN_TENSOR + + +@constexpr +def tuple_index_int_cnt(types, op_name): + """count the int type of types which contains the tuple elements' type.""" + int_cnt = sum(isinstance(ele, mstype.Int) for ele in types) + return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT @constexpr @@ -406,6 +438,12 @@ def compute_new_shape(origin_shape, indexes_shapes_info): return tuple(new_shape) +@constexpr +def convert_int_to_slice(tuple_indexes): + tuple_indexes_new = tuple(slice(i, i+1, 1) for i in tuple_indexes) + return tuple_indexes_new + + @constexpr def convert_ellipsis_to_tensors(slice_number, ellipsis_occupied_dims,