|
|
|
@@ -49,16 +49,19 @@ SET_ITEM_BY_NON_TENSOR = 2 |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def raise_value_error(msg): |
|
|
|
"""Constexpr for raise_value_error.""" |
|
|
|
raise ValueError(msg) |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def raise_index_error(msg): |
|
|
|
"""Constexpr for raise_index_error.""" |
|
|
|
raise IndexError(msg) |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def raise_type_error(msg): |
|
|
|
"""Constexpr for raise_type_error.""" |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
|
|
|
|
@@ -77,6 +80,7 @@ def check_equal(param1, param2, msg="{},{}"): |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def make_empty_slice(): |
|
|
|
"""Creates a empty slice.""" |
|
|
|
return slice(None, None, None) |
|
|
|
|
|
|
|
|
|
|
|
@@ -179,6 +183,7 @@ tensor_operator_registry.register('make_tensor', make_tensor) |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def judge_data_dim(data_dim, min_data_dim=0, max_data_dim=8): |
|
|
|
"""Judges whether the data dim is valid.""" |
|
|
|
if data_dim < min_data_dim or data_dim > max_data_dim: |
|
|
|
raise ValueError(f"The input data's dim should in the range of[{min_data_dim}, " |
|
|
|
f"{max_data_dim}], bug actually is '{data_dim}'") |
|
|
|
@@ -244,12 +249,14 @@ def is_same_type(inst, type_): |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def check_valid_dim(dim, name): |
|
|
|
"""Checks whether the dim is valid.""" |
|
|
|
if dim not in (1, 2): |
|
|
|
raise ValueError(f"For {name}, inputs dim must be 1d or 2d") |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
def judge_index_type(index_type, target_type): |
|
|
|
"""Judges whether the index type is valid.""" |
|
|
|
if index_type == target_type or (isinstance(target_type, (list, tuple)) and index_type in target_type): |
|
|
|
return True |
|
|
|
return False |
|
|
|
@@ -270,6 +277,7 @@ def judge_indexes_types(dtypes, target_type): |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def check_type_valid(dtype, target_type, op_name): |
|
|
|
"""Checks whether the dtype is valid.""" |
|
|
|
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type): |
|
|
|
if op_name in (TENSOR_GETITEM, TENSOR_SETITEM): |
|
|
|
raise IndexError( |
|
|
|
@@ -476,6 +484,7 @@ def generate_updates_shape(data_shape, index_shape, op_type): |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def transform_slice_to_ele_list(slice_index, dim_len): |
|
|
|
"""Transforms slice to element list.""" |
|
|
|
slice_obj = slice(slice_index.start, slice_index.stop, slice_index.step) |
|
|
|
start, stop, end = normalize_slice(slice_obj, dim_len) |
|
|
|
slice_ele_list = list(range(start, stop, end)) |
|
|
|
@@ -528,6 +537,7 @@ def scalar_in_sequence(x, y): |
|
|
|
|
|
|
|
@constexpr |
|
|
|
def get_np_eps(input_dtype): |
|
|
|
"""Get numpy eps.""" |
|
|
|
nptype = mstype.dtype_to_nptype(input_dtype) |
|
|
|
eps = np.finfo(nptype).eps |
|
|
|
return float(eps) |
|
|
|
|