|
|
|
@@ -366,20 +366,32 @@ def check_type_same(x_type, base_type): |
|
|
|
Tensor: mstype.tensor_type, |
|
|
|
Parameter: mstype.ref_type |
|
|
|
} |
|
|
|
try: |
|
|
|
if isinstance(base_type, list): |
|
|
|
raise TypeError("The second arg of 'isinstance' must be a type or a tuple of types, but got a list") |
|
|
|
if isinstance(base_type, tuple): |
|
|
|
target_type = tuple(pytype_to_mstype[i] for i in base_type) |
|
|
|
else: |
|
|
|
target_type = (pytype_to_mstype[base_type],) |
|
|
|
if (isinstance(x_type, mstype.Bool) and mstype.Int in target_type) or \ |
|
|
|
(isinstance(x_type, mstype.ref_type) and mstype.tensor_type in target_type): |
|
|
|
return True |
|
|
|
return isinstance(x_type, target_type) |
|
|
|
except KeyError: |
|
|
|
raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " |
|
|
|
f"Tensor, Parameter, or a tuple containing only these types, but got {base_type}") |
|
|
|
|
|
|
|
has_int = False |
|
|
|
has_tensor = False |
|
|
|
|
|
|
|
def to_target_type(origin_type): |
|
|
|
try: |
|
|
|
if isinstance(origin_type, type): |
|
|
|
ret_type = pytype_to_mstype[origin_type] |
|
|
|
if ret_type == mstype.Int: |
|
|
|
nonlocal has_int |
|
|
|
has_int = True |
|
|
|
if ret_type == mstype.tensor_type: |
|
|
|
nonlocal has_tensor |
|
|
|
has_tensor = True |
|
|
|
return (ret_type,) |
|
|
|
if isinstance(origin_type, tuple): |
|
|
|
return tuple(to_target_type(i) for i in origin_type) |
|
|
|
raise TypeError(f"The second arg of 'isinstance' must be a type or a tuple of types, " |
|
|
|
f"but got a {type(origin_type).__name__}") |
|
|
|
except KeyError: |
|
|
|
raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " |
|
|
|
f"Tensor, Parameter, or a tuple containing only these types, but got {origin_type}") |
|
|
|
target_type = to_target_type(base_type) |
|
|
|
if (isinstance(x_type, mstype.Bool) and has_int) or (isinstance(x_type, mstype.ref_type) and has_tensor): |
|
|
|
return True |
|
|
|
return isinstance(x_type, target_type) |
|
|
|
|
|
|
|
|
|
|
|
@constexpr |
|
|
|
|