diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 9c19d6833e..c554823df0 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -75,12 +75,12 @@ rel_fns = { rel_strs = { # scalar compare - Rel.EQ: "equal to {}", - Rel.NE: "not equal to {}", - Rel.LT: "less than {}", - Rel.LE: "less or equal to {}", - Rel.GT: "greater than {}", - Rel.GE: "greater or equal to {}", + Rel.EQ: "== {}", + Rel.NE: "!= {}", + Rel.LT: "< {}", + Rel.LE: "<= {}", + Rel.GT: "> {}", + Rel.GE: ">= {}", # scalar range check Rel.INC_NEITHER: "({}, {})", Rel.INC_LEFT: "[{}, {})", @@ -102,12 +102,16 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) type_except = TypeError if type_mismatch else ValueError + + prim_name = f'in `{prim_name}`' if prim_name else '' + arg_name = f'`{arg_name}`' if arg_name else '' + if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): + raise ValueError(f'{arg_name} {prim_name} must be legal value, but got `{arg_value}`.') if type_mismatch or not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(value) - arg_name = arg_name if arg_name else "parameter" - msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" - raise type_except(f'{msg_prefix} `{arg_name}` should be an {arg_type} and must {rel_str}, but got `{arg_value}`' - f' with type `{type(arg_value).__name__}`.') + raise type_except(f'{arg_name} {prim_name} should be an {type(arg_type).__name__} and must {rel_str}, ' + f'but got `{arg_value}` with type `{type(arg_value).__name__}`.') + return arg_value @@ -123,7 +127,7 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): prim_name = f'in \'{prim_name}\'' if prim_name else '' arg_name = f'\'{prim_name}\'' if arg_name else 'Input value' if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool): - if math.isinf(arg_value) or math.isnan(arg_value): + if math.isinf(arg_value) or math.isnan(arg_value) or np.isinf(arg_value) or np.isnan(arg_value): raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') return arg_value raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') @@ -137,14 +141,15 @@ def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg - number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0] - number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1] """ + rel_fn = Rel.get_fns(rel) prim_name = f'in `{prim_name}`' if prim_name else '' arg_name = f'`{arg_name}`' if arg_name else '' - rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) - excp_cls = TypeError if type_mismatch else ValueError - if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): + if type_mismatch: + raise TypeError(f'{arg_name} {prim_name} must be `{value_type}`, but got `{type(arg_value).__name__}`.') + if not rel_fn(arg_value, lower_limit, upper_limit): rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) - raise excp_cls("{} {} should be in range of {}, but got {:.3f} with type {}.".format( + raise ValueError("{} {} should be in range of {}, but got {:.3e} with type `{}`.".format( arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__)) return arg_value