Browse Source

!611 change error message type

Merge pull request !611 from JichenZhao/errormsg
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
ad75618ca6
3 changed files with 12 additions and 12 deletions
  1. +6
    -6
      mindspore/_checkparam.py
  2. +2
    -2
      tests/ut/python/ops/test_array_ops.py
  3. +4
    -4
      tests/ut/python/ops/test_math_ops.py

+ 6
- 6
mindspore/_checkparam.py View File

@@ -304,10 +304,10 @@ class Validator:
type_names = [get_typename(t) for t in valid_types]
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
if len(valid_types) == 1:
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.')
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.')

@staticmethod
def check_float_legal_value(arg_name, arg_value, prim_name):
@@ -417,8 +417,8 @@ class ParamValidator:
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
raise ValueError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')

if isinstance(arg_value, type(mstype.tensor)):
arg_value = arg_value.element_type()


+ 2
- 2
tests/ut/python/ops/test_array_ops.py View File

@@ -228,10 +228,10 @@ def test_exec():

raise_set = [
('Squeeze_1_Error', {
'block': (lambda x: P.Squeeze(axis=1.2), {'exception': ValueError}),
'block': (lambda x: P.Squeeze(axis=1.2), {'exception': TypeError}),
'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
('Squeeze_2_Error', {
'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': ValueError}),
'block': (lambda x: P.Squeeze(axis=((1.2, 1.3))), {'exception': TypeError}),
'desc_inputs': [Tensor(np.ones(shape=[3, 1, 5]))]}),
('ReduceSum_Error', {
'block': (lambda x: P.ReduceSum(keep_dims=1), {'exception': TypeError}),


+ 4
- 4
tests/ut/python/ops/test_math_ops.py View File

@@ -401,16 +401,16 @@ def test_exec():

raise_set = [
('StridedSlice_1_Error', {
'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': ValueError}),
'block': (lambda x: P.StridedSlice(begin_mask="1"), {'exception': TypeError}),
'desc_inputs': [0]}),
('StridedSlice_2_Error', {
'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': ValueError}),
'block': (lambda x: P.StridedSlice(end_mask="1"), {'exception': TypeError}),
'desc_inputs': [0]}),
('StridedSlice_3_Error', {
'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': ValueError}),
'block': (lambda x: P.StridedSlice(ellipsis_mask=1.1), {'exception': TypeError}),
'desc_inputs': [0]}),
('StridedSlice_4_Error', {
'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': ValueError}),
'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
'desc_inputs': [0]}),
]



Loading…
Cancel
Save