Browse Source

!11205 master_tensor_getitem_debug

From: @yepei6
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
0018070b54
1 changed files with 57 additions and 47 deletions
  1. +57
    -47
      mindspore/ops/composite/multitype_ops/_constexpr_utils.py

+ 57
- 47
mindspore/ops/composite/multitype_ops/_constexpr_utils.py View File

@@ -147,13 +147,15 @@ def judge_index_type(index_type, target_type):
@constexpr @constexpr
def check_type_valid(dtype, target_type, op_name): def check_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise TypeError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
raise TypeError(
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")




@constexpr @constexpr
def check_index_type_valid(dtype, target_type, op_name): def check_index_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise IndexError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
raise IndexError(
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")




@constexpr @constexpr
@@ -189,7 +191,8 @@ def get_pos_of_indexes_types(indexes_types, op_name):
raise IndexError(f"For '{op_name}', the index elements only support " raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.") f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
if len(ellipsis_positions) > 1: if len(ellipsis_positions) > 1:
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")
raise IndexError(
f"For '{op_name}, an index can only have a single ellipsis('...')")


return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \ return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
tensor_positions, sequence_positions tensor_positions, sequence_positions
@@ -260,7 +263,7 @@ def ellipsis2slice(input_, shape):
return tuple(result) return tuple(result)




@ constexpr
@constexpr
def slice2indices(input_slices, shape): def slice2indices(input_slices, shape):
""" """
Converts slice to indices. Converts slice to indices.
@@ -285,7 +288,7 @@ def slice2indices(input_slices, shape):
return ravel return ravel




@ constexpr
@constexpr
def check_indices(indices_size, index): def check_indices(indices_size, index):
"""Checks indices whether is empty.""" """Checks indices whether is empty."""
if indices_size < 1: if indices_size < 1:
@@ -294,7 +297,7 @@ def check_indices(indices_size, index):
return indices_size return indices_size




@ constexpr
@constexpr
def check_indices_value_size(indices_size, value_size): def check_indices_value_size(indices_size, value_size):
"""Checks if the sizes are already matched.""" """Checks if the sizes are already matched."""
if value_size < 1: if value_size < 1:
@@ -307,7 +310,7 @@ def check_indices_value_size(indices_size, value_size):
return value_size return value_size




@ constexpr
@constexpr
def integer_to_indices(index, shape): def integer_to_indices(index, shape):
"""Converts int or tuple[int] to indices.""" """Converts int or tuple[int] to indices."""
size = reduce(lambda x, y: x * y, shape) size = reduce(lambda x, y: x * y, shape)
@@ -317,7 +320,7 @@ def integer_to_indices(index, shape):
return Tensor(value, dtype=mstype.int32) return Tensor(value, dtype=mstype.int32)




@ constexpr
@constexpr
def tuple_element_is_int(indexs): def tuple_element_is_int(indexs):
"""Judges tuple element type.""" """Judges tuple element type."""
if not indexs: if not indexs:
@@ -330,18 +333,19 @@ def tuple_element_is_int(indexs):
return False return False




@ constexpr
@constexpr
def tuple_index_int_cnt(types, op_name): def tuple_index_int_cnt(types, op_name):
"""count the int type of types which contains the tuple elements' type.""" """count the int type of types which contains the tuple elements' type."""
int_cnt = sum(isinstance(ele, mstype.Int) for ele in types) int_cnt = sum(isinstance(ele, mstype.Int) for ele in types)
return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT return ALL_INT if int_cnt == len(types) else NO_INT if int_cnt == 0 else CONTAIN_INT




@ constexpr
@constexpr
def tuple_index_type_cnt(types, op_name): def tuple_index_type_cnt(types, op_name):
"""count the tensor type of types which contains the tuple elements' type.""" """count the tensor type of types which contains the tuple elements' type."""
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types) tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
basic_cnt = sum(isinstance(
ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
if tensor_cnt == len(types): if tensor_cnt == len(types):
return ALL_TENSOR return ALL_TENSOR
if basic_cnt == len(types): if basic_cnt == len(types):
@@ -349,7 +353,7 @@ def tuple_index_type_cnt(types, op_name):
return MIXED return MIXED




@ constexpr
@constexpr
def check_value_elements(data_dtype, types): def check_value_elements(data_dtype, types):
"""Judges the type of all elements of the tuple.""" """Judges the type of all elements of the tuple."""
tensors_number = 0 tensors_number = 0
@@ -377,10 +381,10 @@ def check_value_elements(data_dtype, types):
# TODO to del # TODO to del




@ constexpr
@constexpr
def get_index_tensor_dtype(dtype): def get_index_tensor_dtype(dtype):
"""Check a tuple of tensor data type.""" """Check a tuple of tensor data type."""
if dtype == mstype.int32:
if dtype in mstype.int_type:
return INT_ return INT_
if dtype == mstype.bool_: if dtype == mstype.bool_:
return BOOL_ return BOOL_
@@ -389,7 +393,7 @@ def get_index_tensor_dtype(dtype):




# TODO to del # TODO to del
@ constexpr
@constexpr
def check_index_tensors_dtype(indexes_types, op_name): def check_index_tensors_dtype(indexes_types, op_name):
"""Check a tuple of tensor data type.""" """Check a tuple of tensor data type."""
for index_type in indexes_types: for index_type in indexes_types:
@@ -400,7 +404,7 @@ def check_index_tensors_dtype(indexes_types, op_name):




# TODO to del # TODO to del
@ constexpr
@constexpr
def check_index_tensor_dtype(index_type, op_name): def check_index_tensor_dtype(index_type, op_name):
"""Check a tensor data type.""" """Check a tensor data type."""
if index_type in (mstype.int32, mstype.int64): if index_type in (mstype.int32, mstype.int64):
@@ -410,7 +414,7 @@ def check_index_tensor_dtype(index_type, op_name):




# TODO to del # TODO to del
@ constexpr
@constexpr
def check_tensors_dtype_same(data_dtype, value_dtype, op_name): def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
"""Check tensors data type same.""" """Check tensors data type same."""
if value_dtype == data_dtype: if value_dtype == data_dtype:
@@ -419,7 +423,7 @@ def check_tensors_dtype_same(data_dtype, value_dtype, op_name):
f"is not consistent with assigned tensor data type {data_dtype}.") f"is not consistent with assigned tensor data type {data_dtype}.")




@ constexpr
@constexpr
def generate_broadcast_shape(shapes, op_name): def generate_broadcast_shape(shapes, op_name):
"""Generate broadcast shape for a tuple of shape.""" """Generate broadcast shape for a tuple of shape."""
if not shapes: if not shapes:
@@ -428,13 +432,14 @@ def generate_broadcast_shape(shapes, op_name):
for i, shape in enumerate(shapes): for i, shape in enumerate(shapes):
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.") logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
try: try:
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
broadcast_shape = op_utils.get_broadcast_shape(
broadcast_shape, shape, op_name)
except ValueError as ex: except ValueError as ex:
raise IndexError(ex) raise IndexError(ex)
return tuple(broadcast_shape) return tuple(broadcast_shape)




@ constexpr
@constexpr
def check_two_shapes_need_broadcast(shape_x, shape_y): def check_two_shapes_need_broadcast(shape_x, shape_y):
"""Check two shapes need broadcast.""" """Check two shapes need broadcast."""
error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape " error = ValueError(f"For 'tensor setitem with tensor', the value tensor shape "
@@ -451,14 +456,14 @@ def check_two_shapes_need_broadcast(shape_x, shape_y):
return True return True




@ constexpr
@constexpr
def compute_multiples(origin_shape, broadcast_shape): def compute_multiples(origin_shape, broadcast_shape):
"""Compute multiples between origin shape with broadcast shape.""" """Compute multiples between origin shape with broadcast shape."""
len_gap = len(broadcast_shape) - len(origin_shape) len_gap = len(broadcast_shape) - len(origin_shape)
return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape)) return broadcast_shape[0:len_gap] + tuple(map(lambda x, y: x // y, broadcast_shape[len_gap:], origin_shape))




@ constexpr
@constexpr
def compute_new_shape(origin_shape, indexes_shapes_info): def compute_new_shape(origin_shape, indexes_shapes_info):
"""Compute new shape between origin shape with final shape.""" """Compute new shape between origin shape with final shape."""
new_shape = [] new_shape = []
@@ -470,21 +475,22 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
return tuple(new_shape) return tuple(new_shape)




@ constexpr
@constexpr
def check_sequence_index_type(sequence_index, op_name): def check_sequence_index_type(sequence_index, op_name):
"""check if the item's type of list_index is bool or int""" """check if the item's type of list_index is bool or int"""
if not all([isinstance(index, (int, bool)) for index in sequence_index]):
raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), "
f"but got {type(index)} in array")
for index in sequence_index:
if not isinstance(index, int):
raise IndexError(f"In the {op_name} operation, only support 'inter' or 'boolean' array(list/tuple), "
f"but got {type(index)} in array.")




@ constexpr
@constexpr
def convert_int_to_slice(tuple_index): def convert_int_to_slice(tuple_index):
tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index) tuple_index_new = tuple(slice(i, i+1, 1) for i in tuple_index)
return tuple_index_new return tuple_index_new




@ constexpr
@constexpr
def check_and_transform_int_index(index, shape, op_name): def check_and_transform_int_index(index, shape, op_name):
if index < -shape or index >= shape: if index < -shape or index >= shape:
raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit " raise IndexError(f"In the \"{op_name}\", the index should in the range [-{shape}, {shape-1}] to fit "
@@ -494,16 +500,20 @@ def check_and_transform_int_index(index, shape, op_name):
return index return index




@ constexpr
@constexpr
def transform_sequence_index(sequence_index, shape, op_name): def transform_sequence_index(sequence_index, shape, op_name):
"""transform list or tuple with integer and boolean to tuple with integer index""" """transform list or tuple with integer and boolean to tuple with integer index"""
bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
bool_count = len(
list(filter(lambda index: isinstance(index, bool), sequence_index)))
int_count = len(
list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
if int_count == 0: if int_count == 0:
if bool_count == shape: if bool_count == shape:
list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
list_index = list(
filter(lambda i: sequence_index[i], range(bool_count)))
else: else:
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
raise IndexError(
"The boolean array should have the same length with the corresponding dimensiton")
else: else:
list_index = [int(index) for index in sequence_index] list_index = [int(index) for index in sequence_index]
for i, index in enumerate(list_index): for i, index in enumerate(list_index):
@@ -512,7 +522,7 @@ def transform_sequence_index(sequence_index, shape, op_name):
return sub_tuple_index return sub_tuple_index




@ constexpr
@constexpr
def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name): def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_name):
"""Convert a slice to a tensor.""" """Convert a slice to a tensor."""
shape = [] shape = []
@@ -540,7 +550,7 @@ def convert_slice_to_tensor(slice_number, final_shape, indexes_shapes_info, op_n
return tensor return tensor




@ constexpr
@constexpr
def check_shapes_same(value_shapes, op_name): def check_shapes_same(value_shapes, op_name):
"""Check if the shapes in the tuple are consistent.""" """Check if the shapes in the tuple are consistent."""
for i, shape in enumerate(value_shapes): for i, shape in enumerate(value_shapes):
@@ -550,7 +560,7 @@ def check_shapes_same(value_shapes, op_name):
return True return True




@ constexpr
@constexpr
def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type): def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_type):
"""Convert a scalar to a tensor.""" """Convert a scalar to a tensor."""
if op_type == SET_ITEM_BY_ONE_TENSOR: if op_type == SET_ITEM_BY_ONE_TENSOR:
@@ -563,7 +573,7 @@ def convert_scalar_to_tensor(data_shape, data_dtype, indices_shape, value, op_ty
f" is not consistent with the assigned tensor data type {data_dtype}.") f" is not consistent with the assigned tensor data type {data_dtype}.")




@ constexpr
@constexpr
def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type): def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value, op_type):
"""Convert a tuple of scalar to a tensor.""" """Convert a tuple of scalar to a tensor."""
updates_shape = generate_updates_shape(data_shape, index_shape, op_type) updates_shape = generate_updates_shape(data_shape, index_shape, op_type)
@@ -575,7 +585,7 @@ def convert_tuple_of_scalar_to_tensor(data_shape, data_dtype, index_shape, value
return Tensor(np.tile(array, reps)) return Tensor(np.tile(array, reps))




@ constexpr
@constexpr
def generate_updates_shape(data_shape, index_shape, op_type): def generate_updates_shape(data_shape, index_shape, op_type):
"""Generate updates shape for 'tensor setitem'.""" """Generate updates shape for 'tensor setitem'."""
if op_type == SET_ITEM_BY_ONE_TENSOR: if op_type == SET_ITEM_BY_ONE_TENSOR:
@@ -585,7 +595,7 @@ def generate_updates_shape(data_shape, index_shape, op_type):
return updates_shape return updates_shape




@ constexpr
@constexpr
def check_tuple_index_len(data_rank, tuple_index_len, op_name): def check_tuple_index_len(data_rank, tuple_index_len, op_name):
"""Check if the number of index tensor exceeds the dimension of the operated tensor.""" """Check if the number of index tensor exceeds the dimension of the operated tensor."""
if tuple_index_len <= data_rank: if tuple_index_len <= data_rank:
@@ -594,7 +604,7 @@ def check_tuple_index_len(data_rank, tuple_index_len, op_name):
f"is greater than the dimension {data_rank} of the operated tensor.") f"is greater than the dimension {data_rank} of the operated tensor.")




@ constexpr
@constexpr
def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes, def generate_index_info_from_tuple_of_mixed_tensors(data_shape, indexes_types, tensor_indexes_shapes,
tensor_indexes_dtypes, slice_indexes, op_name): tensor_indexes_dtypes, slice_indexes, op_name):
""" """
@@ -694,14 +704,14 @@ def scalar_in_sequence(x, y):
return False return False




@ constexpr
@constexpr
def get_np_eps(input_dtype): def get_np_eps(input_dtype):
nptype = mstype.dtype_to_nptype(input_dtype) nptype = mstype.dtype_to_nptype(input_dtype)
eps = np.finfo(nptype).eps eps = np.finfo(nptype).eps
return float(eps) return float(eps)




@ constexpr
@constexpr
def check_number_index_type(number): def check_number_index_type(number):
"""Check if it is int or bool number""" """Check if it is int or bool number"""
if isinstance(number, bool): if isinstance(number, bool):
@@ -712,7 +722,7 @@ def check_number_index_type(number):
.format(number, type(number))) .format(number, type(number)))




@ constexpr
@constexpr
def get_stride_info_from_slice(data_shape, slice_index): def get_stride_info_from_slice(data_shape, slice_index):
"""Get stride info from a python slice""" """Get stride info from a python slice"""
begin, end, step = get_slice_stride(data_shape[0], slice_index) begin, end, step = get_slice_stride(data_shape[0], slice_index)
@@ -726,7 +736,7 @@ def get_stride_info_from_slice(data_shape, slice_index):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides) return tuple(begin_strides), tuple(end_strides), tuple(step_strides)




@ constexpr
@constexpr
def get_stride_info_from_integer(data_shape, number): def get_stride_info_from_integer(data_shape, number):
"""Get stride info from a integer""" """Get stride info from a integer"""
begin_strides = [number] begin_strides = [number]
@@ -752,7 +762,7 @@ def get_slice_stride(dim_size, index_slice):
return start, stop, step return start, stop, step




@ constexpr
@constexpr
def get_stride_info_from_tuple(data_shape, tuple_index): def get_stride_info_from_tuple(data_shape, tuple_index):
"""Get stride info from a tuple""" """Get stride info from a tuple"""
begin_strides, end_strides, step_strides = [], [], [] begin_strides, end_strides, step_strides = [], [], []
@@ -792,14 +802,14 @@ def get_stride_info_from_tuple(data_shape, tuple_index):
return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis return tuple(begin_strides), tuple(end_strides), tuple(step_strides), shrink_axis




@ constexpr
@constexpr
def mstype_eq(x, y): def mstype_eq(x, y):
if x == y: if x == y:
return True return True
return False return False




@ constexpr
@constexpr
def scalar_to_tensor(x): def scalar_to_tensor(x):
"""Convert a scalar to a tensor""" """Convert a scalar to a tensor"""
return Tensor(x) return Tensor(x)

Loading…
Cancel
Save