Browse Source

change error type

tags/v0.2.0-alpha
zhaojichen 5 years ago
parent
commit
8f1d140de1
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/_checkparam.py

+ 3
- 3
mindspore/_checkparam.py View File

@@ -304,9 +304,9 @@ class Validator:
type_names = [get_typename(t) for t in valid_types] type_names = [get_typename(t) for t in valid_types]
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
if len(valid_types) == 1: if len(valid_types) == 1:
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
f' but got {get_typename(arg_type)}.') f' but got {get_typename(arg_type)}.')
raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
f' but got {get_typename(arg_type)}.') f' but got {get_typename(arg_type)}.')


@staticmethod @staticmethod
@@ -417,7 +417,7 @@ class ParamValidator:
"""func for raising error message when check failed""" """func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types] type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types) num_types = len(valid_types)
raise ValueError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')


if isinstance(arg_value, type(mstype.tensor)): if isinstance(arg_value, type(mstype.tensor)):


Loading…
Cancel
Save