From 8f1d140de1a00ce2ffcb924abfdcaa50aee906df Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 23 Apr 2020 04:14:47 -0400 Subject: [PATCH 1/3] change error type --- mindspore/_checkparam.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 3543f58cf5..dba1c13b3b 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -304,9 +304,9 @@ class Validator: type_names = [get_typename(t) for t in valid_types] msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' if len(valid_types) == 1: - raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' f' but got {get_typename(arg_type)}.') - raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' f' but got {get_typename(arg_type)}.') @staticmethod @@ -417,7 +417,7 @@ class ParamValidator: """func for raising error message when check failed""" type_names = [t.__name__ for t in valid_types] num_types = len(valid_types) - raise ValueError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' + raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') if isinstance(arg_value, type(mstype.tensor)): From ff57caceb98c1261642c7f48a60a3116435f3c8d Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 23 Apr 2020 04:31:14 -0400 Subject: [PATCH 2/3] change error type --- mindspore/_checkparam.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index dba1c13b3b..9ecf0c9e24 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -305,9 +305,9 @@ class Validator: msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' if len(valid_types) == 1: raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' - f' but got {get_typename(arg_type)}.') + f' but got {get_typename(arg_type)}.') raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' - f' but got {get_typename(arg_type)}.') + f' but got {get_typename(arg_type)}.') @staticmethod def check_float_legal_value(arg_name, arg_value, prim_name): @@ -418,7 +418,7 @@ class ParamValidator: type_names = [t.__name__ for t in valid_types] num_types = len(valid_types) raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' - f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') + f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') if isinstance(arg_value, type(mstype.tensor)): arg_value = arg_value.element_type() From 8963e395163fabecf8637a19729bbc14ff52c2fd Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 23 Apr 2020 05:29:26 -0400 Subject: [PATCH 3/3] change error type --- tests/ut/python/ops/test_array_ops.py | 4 ++-- tests/ut/python/ops/test_math_ops.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index faaa9d5402..01e7e32d50 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -228,10 +228,10 @@ def test_exec(): raise_set = [ ('Squeeze_1_Error', { - 'block': (lambda x: P.Squeeze(axis=1.2), {'exception': ValueError}), + 'block': (lambda x: P.Squeeze(axis=1.2), {'exception': TypeError}), 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), ('Squeeze_2_Error', { - 'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}), + 'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': TypeError}), 'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}), ('ReduceSum_Error', { 'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}), diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py index 7f8717d4e6..b866c7c556 100755 --- a/tests/ut/python/ops/test_math_ops.py +++ b/tests/ut/python/ops/test_math_ops.py @@ -401,16 +401,16 @@ def test_exec(): raise_set = [ ('StridedSlice_1_Error', { - 'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': ValueError}), + 'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': TypeError}), 'desc_inputs': [0]}), ('StridedSlice_2_Error', { - 'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': ValueError}), + 'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': TypeError}), 'desc_inputs': [0]}), ('StridedSlice_3_Error', { - 'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': ValueError}), + 'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': TypeError}), 'desc_inputs': [0]}), ('StridedSlice_4_Error', { - 'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': ValueError}), + 'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}), 'desc_inputs': [0]}), ]