Merge pull request !611 from JichenZhao/errormsgtags/v0.2.0-alpha
| @@ -304,10 +304,10 @@ class Validator: | |||||
| type_names = [get_typename(t) for t in valid_types] | type_names = [get_typename(t) for t in valid_types] | ||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | ||||
| if len(valid_types) == 1: | if len(valid_types) == 1: | ||||
| raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| @staticmethod | @staticmethod | ||||
| def check_float_legal_value(arg_name, arg_value, prim_name): | def check_float_legal_value(arg_name, arg_value, prim_name): | ||||
| @@ -417,8 +417,8 @@ class ParamValidator: | |||||
| """func for raising error message when check failed""" | """func for raising error message when check failed""" | ||||
| type_names = [t.__name__ for t in valid_types] | type_names = [t.__name__ for t in valid_types] | ||||
| num_types = len(valid_types) | num_types = len(valid_types) | ||||
| raise ValueError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||||
| raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | |||||
| if isinstance(arg_value, type(mstype.tensor)): | if isinstance(arg_value, type(mstype.tensor)): | ||||
| arg_value = arg_value.element_type() | arg_value = arg_value.element_type() | ||||
| @@ -228,10 +228,10 @@ def test_exec(): | |||||
| raise_set = [ | raise_set = [ | ||||
| ('Squeeze_1_Error', { | ('Squeeze_1_Error', { | ||||
| 'block': (lambda x: P.Squeeze(axis=1.2), {'exception': ValueError}), | |||||
| 'block': (lambda x: P.Squeeze(axis=1.2), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | ||||
| ('Squeeze_2_Error', { | ('Squeeze_2_Error', { | ||||
| 'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}), | |||||
| 'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': TypeError}), | |||||
| 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), | ||||
| ('ReduceSum_Error', { | ('ReduceSum_Error', { | ||||
| 'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}), | 'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}), | ||||
| @@ -401,16 +401,16 @@ def test_exec(): | |||||
| raise_set = [ | raise_set = [ | ||||
| ('StridedSlice_1_Error', { | ('StridedSlice_1_Error', { | ||||
| 'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': ValueError}), | |||||
| 'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': TypeError}), | |||||
| 'desc_inputs': [0]}), | 'desc_inputs': [0]}), | ||||
| ('StridedSlice_2_Error', { | ('StridedSlice_2_Error', { | ||||
| 'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': ValueError}), | |||||
| 'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': TypeError}), | |||||
| 'desc_inputs': [0]}), | 'desc_inputs': [0]}), | ||||
| ('StridedSlice_3_Error', { | ('StridedSlice_3_Error', { | ||||
| 'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': ValueError}), | |||||
| 'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': TypeError}), | |||||
| 'desc_inputs': [0]}), | 'desc_inputs': [0]}), | ||||
| ('StridedSlice_4_Error', { | ('StridedSlice_4_Error', { | ||||
| 'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': ValueError}), | |||||
| 'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}), | |||||
| 'desc_inputs': [0]}), | 'desc_inputs': [0]}), | ||||
| ] | ] | ||||