| @@ -366,20 +366,32 @@ def check_type_same(x_type, base_type): | |||||
| Tensor: mstype.tensor_type, | Tensor: mstype.tensor_type, | ||||
| Parameter: mstype.ref_type | Parameter: mstype.ref_type | ||||
| } | } | ||||
| try: | |||||
| if isinstance(base_type, list): | |||||
| raise TypeError("The second arg of 'isinstance' must be a type or a tuple of types, but got a list") | |||||
| if isinstance(base_type, tuple): | |||||
| target_type = tuple(pytype_to_mstype[i] for i in base_type) | |||||
| else: | |||||
| target_type = (pytype_to_mstype[base_type],) | |||||
| if (isinstance(x_type, mstype.Bool) and mstype.Int in target_type) or \ | |||||
| (isinstance(x_type, mstype.ref_type) and mstype.tensor_type in target_type): | |||||
| return True | |||||
| return isinstance(x_type, target_type) | |||||
| except KeyError: | |||||
| raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " | |||||
| f"Tensor, Parameter, or a tuple containing only these types, but got {base_type}") | |||||
| has_int = False | |||||
| has_tensor = False | |||||
| def to_target_type(origin_type): | |||||
| try: | |||||
| if isinstance(origin_type, type): | |||||
| ret_type = pytype_to_mstype[origin_type] | |||||
| if ret_type == mstype.Int: | |||||
| nonlocal has_int | |||||
| has_int = True | |||||
| if ret_type == mstype.tensor_type: | |||||
| nonlocal has_tensor | |||||
| has_tensor = True | |||||
| return (ret_type,) | |||||
| if isinstance(origin_type, tuple): | |||||
| return tuple(to_target_type(i) for i in origin_type) | |||||
| raise TypeError(f"The second arg of 'isinstance' must be a type or a tuple of types, " | |||||
| f"but got a {type(origin_type).__name__}") | |||||
| except KeyError: | |||||
| raise TypeError(f"The second arg of 'isinstance' should be bool, int, float, str, list, tuple, " | |||||
| f"Tensor, Parameter, or a tuple containing only these types, but got {origin_type}") | |||||
| target_type = to_target_type(base_type) | |||||
| if (isinstance(x_type, mstype.Bool) and has_int) or (isinstance(x_type, mstype.ref_type) and has_tensor): | |||||
| return True | |||||
| return isinstance(x_type, target_type) | |||||
| @constexpr | @constexpr | ||||
| @@ -43,10 +43,10 @@ def test_isinstance(): | |||||
| is_int = isinstance(self.int_member, int) | is_int = isinstance(self.int_member, int) | ||||
| is_float = isinstance(self.float_member, float) | is_float = isinstance(self.float_member, float) | ||||
| is_bool = isinstance(self.bool_member, bool) | is_bool = isinstance(self.bool_member, bool) | ||||
| bool_is_int = isinstance(self.bool_member, int) | |||||
| bool_is_int = isinstance(self.bool_member, (((int,)), float)) | |||||
| is_string = isinstance(self.string_member, str) | is_string = isinstance(self.string_member, str) | ||||
| is_parameter = isinstance(self.weight, Parameter) | is_parameter = isinstance(self.weight, Parameter) | ||||
| parameter_is_tensor = isinstance(self.weight, Tensor) | |||||
| parameter_is_tensor = isinstance(self.weight, ((Tensor, float), int)) | |||||
| is_tensor_const = isinstance(self.tensor_member, Tensor) | is_tensor_const = isinstance(self.tensor_member, Tensor) | ||||
| is_tensor_var = isinstance(x, Tensor) | is_tensor_var = isinstance(x, Tensor) | ||||
| is_tuple_const = isinstance(self.tuple_member, tuple) | is_tuple_const = isinstance(self.tuple_member, tuple) | ||||
| @@ -88,8 +88,7 @@ def test_isinstance_not_supported(): | |||||
| net = Net() | net = Net() | ||||
| with pytest.raises(TypeError) as err: | with pytest.raises(TypeError) as err: | ||||
| net() | net() | ||||
| assert "The second arg of 'isinstance' should be bool, int, float, str, list, tuple, Tensor, Parameter, " \ | |||||
| "or a tuple containing only these types, but got None" in str(err.value) | |||||
| assert "The second arg of 'isinstance' must be a type or a tuple of types, but got a NoneType" in str(err.value) | |||||
| def test_isinstance_second_arg_is_list(): | def test_isinstance_second_arg_is_list(): | ||||