| @@ -147,13 +147,15 @@ def judge_index_type(index_type, target_type): | |||||
| @constexpr | @constexpr | ||||
| def check_type_valid(dtype, target_type, op_name): | def check_type_valid(dtype, target_type, op_name): | ||||
| if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): | if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): | ||||
| raise TypeError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") | |||||
| raise TypeError( | |||||
| f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") | |||||
| @constexpr | @constexpr | ||||
| def check_index_type_valid(dtype, target_type, op_name): | def check_index_type_valid(dtype, target_type, op_name): | ||||
| if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): | if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): | ||||
| raise IndexError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") | |||||
| raise IndexError( | |||||
| f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.") | |||||
| @constexpr | @constexpr | ||||
| @@ -189,7 +191,8 @@ def get_pos_of_indexes_types(indexes_types, op_name): | |||||
| raise IndexError(f"For '{op_name}', the index elements only support " | raise IndexError(f"For '{op_name}', the index elements only support " | ||||
| f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.") | f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.") | ||||
| if len(ellipsis_positions) > 1: | if len(ellipsis_positions) > 1: | ||||
| raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')") | |||||
| raise IndexError( | |||||
| f"For '{op_name}, an index can only have a single ellipsis('...')") | |||||
| return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \ | return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \ | ||||
| tensor_positions, sequence_positions | tensor_positions, sequence_positions | ||||
| @@ -260,7 +263,7 @@ def ellipsis2slice(input_, shape): | |||||
| return tuple(result) | return tuple(result) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def slice2indices(input_slices, shape): | def slice2indices(input_slices, shape): | ||||
| """ | """ | ||||
| Converts slice to indices. | Converts slice to indices. | ||||
| @@ -285,7 +288,7 @@ def slice2indices(input_slices, shape): | |||||
| return ravel | return ravel | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_indices(indices_size, index): | def check_indices(indices_size, index): | ||||
| """Checks indices whether is empty.""" | """Checks indices whether is empty.""" | ||||
| if indices_size < 1: | if indices_size < 1: | ||||
| @@ -294,7 +297,7 @@ def check_indices(indices_size, index): | |||||
| return indices_size | return indices_size | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_indices_value_size(indices_size, value_size): | def check_indices_value_size(indices_size, value_size): | ||||
| """Checks if the sizes are already matched.""" | """Checks if the sizes are already matched.""" | ||||
| if value_size < 1: | if value_size < 1: | ||||
| @@ -307,7 +310,7 @@ def check_indices_value_size(indices_size, value_size): | |||||
| return value_size | return value_size | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def integer_to_indices(index, shape): | def integer_to_indices(index, shape): | ||||
| """Converts int or tuple[int] to indices.""" | """Converts int or tuple[int] to indices.""" | ||||
| size = reduce(lambda x, y: x * y, shape) | size = reduce(lambda x, y: x * y, shape) | ||||
| @@ -317,7 +320,7 @@ def integer_to_indices(index, shape): | |||||
| return Tensor(value, dtype=mstype.int32) | return Tensor(value, dtype=mstype.int32) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def tuple_element_is_int(indexs): | def tuple_element_is_int(indexs): | ||||
| """Judges tuple element type.""" | """Judges tuple element type.""" | ||||
| if not indexs: | if not indexs: | ||||
| @@ -330,18 +333,19 @@ def tuple_element_is_int(indexs): | |||||
| return False | return False | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def tuple_index_int_cnt(types, op_name): | def tuple_index_int_cnt(types, op_name): | ||||
| """count the int type of types which contains the tuple elements' type.""" | """count the int type of types which contains the tuple elements' type.""" | ||||
| int_cnt = sum(isinstance(ele, mstype.Int) for ele in types) | int_cnt = sum(isinstance(ele, mstype.Int) for ele in types) | ||||
| return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT | return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def tuple_index_type_cnt(types, op_name): | def tuple_index_type_cnt(types, op_name): | ||||
| """count the tensor type of types which contains the tuple elements' type.""" | """count the tensor type of types which contains the tuple elements' type.""" | ||||
| tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) | tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) | ||||
| basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types) | |||||
| basic_cnt = sum(isinstance( | |||||
| ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types) | |||||
| if tensor_cnt == len(types): | if tensor_cnt == len(types): | ||||
| return ALL_TENSOR | return ALL_TENSOR | ||||
| if basic_cnt == len(types): | if basic_cnt == len(types): | ||||
| @@ -349,7 +353,7 @@ def tuple_index_type_cnt(types, op_name): | |||||
| return MIXED | return MIXED | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_value_elements(data_dtype, types): | def check_value_elements(data_dtype, types): | ||||
| """Judges the type of all elements of the tuple.""" | """Judges the type of all elements of the tuple.""" | ||||
| tensors_number = 0 | tensors_number = 0 | ||||
| @@ -377,10 +381,10 @@ def check_value_elements(data_dtype, types): | |||||
| # TODO to del | # TODO to del | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def get_index_tensor_dtype(dtype): | def get_index_tensor_dtype(dtype): | ||||
| """Check a tuple of tensor data type.""" | """Check a tuple of tensor data type.""" | ||||
| if dtype == mstype.int32: | |||||
| if dtype in mstype.int_type: | |||||
| return INT_ | return INT_ | ||||
| if dtype == mstype.bool_: | if dtype == mstype.bool_: | ||||
| return BOOL_ | return BOOL_ | ||||
| @@ -389,7 +393,7 @@ def get_index_tensor_dtype(dtype): | |||||
| # TODO to del | # TODO to del | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_index_tensors_dtype(indexes_types, op_name): | def check_index_tensors_dtype(indexes_types, op_name): | ||||
| """Check a tuple of tensor data type.""" | """Check a tuple of tensor data type.""" | ||||
| for index_type in indexes_types: | for index_type in indexes_types: | ||||
| @@ -400,7 +404,7 @@ def check_index_tensors_dtype(indexes_types, op_name): | |||||
| # TODO to del | # TODO to del | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_index_tensor_dtype(index_type, op_name): | def check_index_tensor_dtype(index_type, op_name): | ||||
| """Check a tensor data type.""" | """Check a tensor data type.""" | ||||
| if index_type in (mstype.int32, mstype.int64): | if index_type in (mstype.int32, mstype.int64): | ||||
| @@ -410,7 +414,7 @@ def check_index_tensor_dtype(index_type, op_name): | |||||
| # TODO to del | # TODO to del | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_tensors_dtype_same(data_dtype, value_dtype, op_name): | def check_tensors_dtype_same(data_dtype, value_dtype, op_name): | ||||
| """Check tensors data type same.""" | """Check tensors data type same.""" | ||||
| if value_dtype == data_dtype: | if value_dtype == data_dtype: | ||||
| @@ -419,7 +423,7 @@ def check_tensors_dtype_same(data_dtype, value_dtype, op_name): | |||||
| f"is not consistent with assigned tensor data type {data_dtype}.") | f"is not consistent with assigned tensor data type {data_dtype}.") | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def generate_broadcast_shape(shapes, op_name): | def generate_broadcast_shape(shapes, op_name): | ||||
| """Generate broadcast shape for a tuple of shape.""" | """Generate broadcast shape for a tuple of shape.""" | ||||
| if not shapes: | if not shapes: | ||||
| @@ -428,13 +432,14 @@ def generate_broadcast_shape(shapes, op_name): | |||||
| for i, shape in enumerate(shapes): | for i, shape in enumerate(shapes): | ||||
| logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") | ||||
| try: | try: | ||||
| broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name) | |||||
| broadcast_shape = op_utils.get_broadcast_shape( | |||||
| broadcast_shape, shape, op_name) | |||||
| except ValueError as ex: | except ValueError as ex: | ||||
| raise IndexError(ex) | raise IndexError(ex) | ||||
| return tuple(broadcast_shape) | return tuple(broadcast_shape) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_two_shapes_need_broadcast(shape_x, shape_y): | def check_two_shapes_need_broadcast(shape_x, shape_y): | ||||
| """Check two shapes need broadcast.""" | """Check two shapes need broadcast.""" | ||||
| error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape " | error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape " | ||||
| @@ -451,14 +456,14 @@ def check_two_shapes_need_broadcast(shape_x, shape_y): | |||||
| return True | return True | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def compute_multiples(origin_shape, broadcast_shape): | def compute_multiples(origin_shape, broadcast_shape): | ||||
| """Compute multiples between origin shape with broadcast shape.""" | """Compute multiples between origin shape with broadcast shape.""" | ||||
| len_gap = len(broadcast_shape) - len(origin_shape) | len_gap = len(broadcast_shape) - len(origin_shape) | ||||
| return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) | return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def compute_new_shape(origin_shape, indexes_shapes_info): | def compute_new_shape(origin_shape, indexes_shapes_info): | ||||
| """Compute new shape between origin shape with final shape.""" | """Compute new shape between origin shape with final shape.""" | ||||
| new_shape = [] | new_shape = [] | ||||
| @@ -470,21 +475,22 @@ def compute_new_shape(origin_shape, indexes_shapes_info): | |||||
| return tuple(new_shape) | return tuple(new_shape) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_sequence_index_type(sequence_index, op_name): | def check_sequence_index_type(sequence_index, op_name): | ||||
| """check if the item's type of list_index is bool or int""" | """check if the item's type of list_index is bool or int""" | ||||
| if not all([isinstance(index, (int, bool)) for index in sequence_index]): | |||||
| raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), " | |||||
| f"but got {type(index)} in array") | |||||
| for index in sequence_index: | |||||
| if not isinstance(index, int): | |||||
| raise IndexError(f"In the {op_name} operation, only support 'inter' or 'boolean' array(list/tuple), " | |||||
| f"but got {type(index)} in array.") | |||||
| @ constexpr | |||||
| @constexpr | |||||
| def convert_int_to_slice(tuple_index): | def convert_int_to_slice(tuple_index): | ||||
| tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index) | tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index) | ||||
| return tuple_index_new | return tuple_index_new | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_and_transform_int_index(index, shape, op_name): | def check_and_transform_int_index(index, shape, op_name): | ||||
| if index < -shape or index >= shape: | if index < -shape or index >= shape: | ||||
| raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit " | raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit " | ||||
| @@ -494,16 +500,20 @@ def check_and_transform_int_index(index, shape, op_name): | |||||
| return index | return index | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def transform_sequence_index(sequence_index, shape, op_name): | def transform_sequence_index(sequence_index, shape, op_name): | ||||
| """transform list or tuple with integer and boolean to tuple with integer index""" | """transform list or tuple with integer and boolean to tuple with integer index""" | ||||
| bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index))) | |||||
| int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count | |||||
| bool_count = len( | |||||
| list(filter(lambda index: isinstance(index, bool), sequence_index))) | |||||
| int_count = len( | |||||
| list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count | |||||
| if int_count == 0: | if int_count == 0: | ||||
| if bool_count == shape: | if bool_count == shape: | ||||
| list_index = list(filter(lambda i: sequence_index[i], range(bool_count))) | |||||
| list_index = list( | |||||
| filter(lambda i: sequence_index[i], range(bool_count))) | |||||
| else: | else: | ||||
| raise IndexError("The boolean array should have the same length with the corresponding dimensiton") | |||||
| raise IndexError( | |||||
| "The boolean array should have the same length with the corresponding dimensiton") | |||||
| else: | else: | ||||
| list_index = [int(index) for index in sequence_index] | list_index = [int(index) for index in sequence_index] | ||||
| for i, index in enumerate(list_index): | for i, index in enumerate(list_index): | ||||
| @@ -512,7 +522,7 @@ def transform_sequence_index(sequence_index, shape, op_name): | |||||
| return sub_tuple_index | return sub_tuple_index | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): | def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): | ||||
| """Convert a slice to a tensor.""" | """Convert a slice to a tensor.""" | ||||
| shape = [] | shape = [] | ||||
| @@ -540,7 +550,7 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n | |||||
| return tensor | return tensor | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_shapes_same(value_shapes, op_name): | def check_shapes_same(value_shapes, op_name): | ||||
| """Check if the shapes in the tuple are consistent.""" | """Check if the shapes in the tuple are consistent.""" | ||||
| for i, shape in enumerate(value_shapes): | for i, shape in enumerate(value_shapes): | ||||
| @@ -550,7 +560,7 @@ def check_shapes_same(value_shapes, op_name): | |||||
| return True | return True | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type): | def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type): | ||||
| """Convert a scalar to a tensor.""" | """Convert a scalar to a tensor.""" | ||||
| if op_type == SET_ITEM_BY_ONE_TENSOR: | if op_type == SET_ITEM_BY_ONE_TENSOR: | ||||
| @@ -563,7 +573,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty | |||||
| f" is not consistent with the assigned tensor data type {data_dtype}.") | f" is not consistent with the assigned tensor data type {data_dtype}.") | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type): | def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type): | ||||
| """Convert a tuple of scalar to a tensor.""" | """Convert a tuple of scalar to a tensor.""" | ||||
| updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | updates_shape = generate_updates_shape(data_shape, index_shape, op_type) | ||||
| @@ -575,7 +585,7 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value | |||||
| return Tensor(np.tile(array, reps)) | return Tensor(np.tile(array, reps)) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def generate_updates_shape(data_shape, index_shape, op_type): | def generate_updates_shape(data_shape, index_shape, op_type): | ||||
| """Generate updates shape for 'tensor setitem'.""" | """Generate updates shape for 'tensor setitem'.""" | ||||
| if op_type == SET_ITEM_BY_ONE_TENSOR: | if op_type == SET_ITEM_BY_ONE_TENSOR: | ||||
| @@ -585,7 +595,7 @@ def generate_updates_shape(data_shape, index_shape, op_type): | |||||
| return updates_shape | return updates_shape | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_tuple_index_len(data_rank, tuple_index_len, op_name): | def check_tuple_index_len(data_rank, tuple_index_len, op_name): | ||||
| """Check if the number of index tensor exceeds the dimension of the operated tensor.""" | """Check if the number of index tensor exceeds the dimension of the operated tensor.""" | ||||
| if tuple_index_len <= data_rank: | if tuple_index_len <= data_rank: | ||||
| @@ -594,7 +604,7 @@ def check_tuple_index_len(data_rank, tuple_index_len, op_name): | |||||
| f"is greater than the dimension {data_rank} of the operated tensor.") | f"is greater than the dimension {data_rank} of the operated tensor.") | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes, | def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes, | ||||
| tensor_indexes_dtypes, slice_indexes, op_name): | tensor_indexes_dtypes, slice_indexes, op_name): | ||||
| """ | """ | ||||
| @@ -694,14 +704,14 @@ def scalar_in_sequence(x, y): | |||||
| return False | return False | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def get_np_eps(input_dtype): | def get_np_eps(input_dtype): | ||||
| nptype = mstype.dtype_to_nptype(input_dtype) | nptype = mstype.dtype_to_nptype(input_dtype) | ||||
| eps = np.finfo(nptype).eps | eps = np.finfo(nptype).eps | ||||
| return float(eps) | return float(eps) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def check_number_index_type(number): | def check_number_index_type(number): | ||||
| """Check if it is int or bool number""" | """Check if it is int or bool number""" | ||||
| if isinstance(number, bool): | if isinstance(number, bool): | ||||
| @@ -712,7 +722,7 @@ def check_number_index_type(number): | |||||
| .format(number, type(number))) | .format(number, type(number))) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def get_stride_info_from_slice(data_shape, slice_index): | def get_stride_info_from_slice(data_shape, slice_index): | ||||
| """Get stride info from a python slice""" | """Get stride info from a python slice""" | ||||
| begin, end, step = get_slice_stride(data_shape[0], slice_index) | begin, end, step = get_slice_stride(data_shape[0], slice_index) | ||||
| @@ -726,7 +736,7 @@ def get_stride_info_from_slice(data_shape, slice_index): | |||||
| return tuple(begin_strides), tuple(end_strides), tuple(step_strides) | return tuple(begin_strides), tuple(end_strides), tuple(step_strides) | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def get_stride_info_from_integer(data_shape, number): | def get_stride_info_from_integer(data_shape, number): | ||||
| """Get stride info from a integer""" | """Get stride info from a integer""" | ||||
| begin_strides = [number] | begin_strides = [number] | ||||
| @@ -752,7 +762,7 @@ def get_slice_stride(dim_size, index_slice): | |||||
| return start, stop, step | return start, stop, step | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def get_stride_info_from_tuple(data_shape, tuple_index): | def get_stride_info_from_tuple(data_shape, tuple_index): | ||||
| """Get stride info from a tuple""" | """Get stride info from a tuple""" | ||||
| begin_strides, end_strides, step_strides = [], [], [] | begin_strides, end_strides, step_strides = [], [], [] | ||||
| @@ -792,14 +802,14 @@ def get_stride_info_from_tuple(data_shape, tuple_index): | |||||
| return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis | return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def mstype_eq(x, y): | def mstype_eq(x, y): | ||||
| if x == y: | if x == y: | ||||
| return True | return True | ||||
| return False | return False | ||||
| @ constexpr | |||||
| @constexpr | |||||
| def scalar_to_tensor(x): | def scalar_to_tensor(x): | ||||
| """Convert a scalar to a tensor""" | """Convert a scalar to a tensor""" | ||||
| return Tensor(x) | return Tensor(x) | ||||