| @@ -92,7 +92,7 @@ rel_strs = { | |||
| } | |||
| def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None): | |||
| def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): | |||
| """ | |||
| Check argument integer. | |||
| @@ -100,13 +100,13 @@ def _check_integer(arg_value, value, rel, arg_name=None, prim_name=None): | |||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| """ | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||
| type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) | |||
| type_except = TypeError if type_mismatch else ValueError | |||
| if type_mismatch or not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| arg_name = arg_name if arg_name else "parameter" | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||
| raise type_except(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' | |||
| raise type_except(f'{msg_prefix} `{arg_name}` should be an {arg_type} and must {rel_str}, but got `{arg_value}`' | |||
| f' with type `{type(arg_value).__name__}`.') | |||
| return arg_value | |||
| @@ -149,7 +149,7 @@ class Validator: | |||
| - number = check_positive_int(number) | |||
| - number = check_positive_int(number, "bias") | |||
| """ | |||
| return _check_integer(arg_value, 0, Rel.GT, arg_name, prim_name) | |||
| return check_number(arg_value, 0, Rel.GT, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_negative_int(arg_value, arg_name=None, prim_name=None): | |||
| @@ -160,7 +160,7 @@ class Validator: | |||
| - number = check_negative_int(number) | |||
| - number = check_negative_int(number, "bias") | |||
| """ | |||
| return _check_integer(arg_value, 0, Rel.LT, arg_name, prim_name) | |||
| return check_number(arg_value, 0, Rel.LT, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_non_positive_int(arg_value, arg_name=None, prim_name=None): | |||
| @@ -171,7 +171,7 @@ class Validator: | |||
| - number = check_non_positive_int(number) | |||
| - number = check_non_positive_int(number, "bias") | |||
| """ | |||
| return _check_integer(arg_value, 0, Rel.LE, arg_name, prim_name) | |||
| return check_number(arg_value, 0, Rel.LE, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_non_negative_int(arg_value, arg_name=None, prim_name=None): | |||
| @@ -182,7 +182,52 @@ class Validator: | |||
| - number = check_non_negative_int(number) | |||
| - number = check_non_negative_int(number, "bias") | |||
| """ | |||
| return _check_integer(arg_value, 0, Rel.GE, arg_name, prim_name) | |||
| return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_positive_float(arg_value, arg_name=None, prim_name=None): | |||
| """ | |||
| Check argument is positive float, which mean arg_value > 0. | |||
| Usage: | |||
| - number = check_positive_float(number) | |||
| - number = check_positive_float(number, "bias") | |||
| - number = check_positive_float(number, "bias", "bias_class") | |||
| """ | |||
| return check_number(arg_value, 0, Rel.GT, float, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_negative_float(arg_value, arg_name=None, prim_name=None): | |||
| """ | |||
| Check argument is negative float, which mean arg_value < 0. | |||
| Usage: | |||
| - number = check_negative_float(number) | |||
| - number = check_negative_float(number, "bias") | |||
| """ | |||
| return check_number(arg_value, 0, Rel.LT, float, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_non_positive_float(arg_value, arg_name=None, prim_name=None): | |||
| """ | |||
| Check argument is non-negative float, which mean arg_value <= 0. | |||
| Usage: | |||
| - number = check_non_positive_float(number) | |||
| - number = check_non_positive_float(number, "bias") | |||
| """ | |||
| return check_number(arg_value, 0, Rel.LE, float, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_non_negative_float(arg_value, arg_name=None, prim_name=None): | |||
| """ | |||
| Check argument is non-negative float, which mean arg_value >= 0. | |||
| Usage: | |||
| - number = check_non_negative_float(number) | |||
| - number = check_non_negative_float(number, "bias") | |||
| """ | |||
| return check_number(arg_value, 0, Rel.GE, float, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_number(arg_name, arg_value, value, rel, prim_name): | |||
| @@ -257,16 +302,6 @@ class Validator: | |||
| raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.") | |||
| return padding | |||
| @staticmethod | |||
| def check_float_positive(arg_name, arg_value, prim_name): | |||
| """Float type judgment.""" | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||
| if isinstance(arg_value, float): | |||
| if arg_value > 0: | |||
| return arg_value | |||
| raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.") | |||
| raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | |||
| @staticmethod | |||
| def check_subclass(arg_name, type_, template_types, prim_name): | |||
| """Checks whether some type is subclass of another type""" | |||
| @@ -82,12 +82,6 @@ def check_positive(value, arg_name=""): | |||
| raise ValueError("Input {0}must be greater than 0.".format(arg_name)) | |||
| def check_positive_float(value, arg_name=""): | |||
| arg_name = pad_arg_name(arg_name) | |||
| type_check(value, (float,), arg_name) | |||
| check_positive(value, arg_name) | |||
| def check_2tuple(value, arg_name=""): | |||
| if not (isinstance(value, tuple) and len(value) == 2): | |||
| raise ValueError("Value {0}needs to be a 2-tuple.".format(arg_name)) | |||
| @@ -66,9 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e | |||
| validator.check_positive_int(total_step, 'total_step') | |||
| validator.check_positive_int(step_per_epoch, 'step_per_epoch') | |||
| validator.check_positive_int(decay_epoch, 'decay_epoch') | |||
| validator.check_float_positive('learning_rate', learning_rate, None) | |||
| validator.check_positive_float(learning_rate, 'learning_rate') | |||
| validator.check_float_legal_value('learning_rate', learning_rate, None) | |||
| validator.check_float_positive('decay_rate', decay_rate, None) | |||
| validator.check_positive_float(decay_rate, 'decay_rate') | |||
| validator.check_float_legal_value('decay_rate', decay_rate, None) | |||
| validator.check_value_type('is_stair', is_stair, [bool], None) | |||
| @@ -234,7 +234,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): | |||
| if not isinstance(min_lr, float): | |||
| raise TypeError("min_lr must be float.") | |||
| validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) | |||
| validator.check_float_positive('max_lr', max_lr, None) | |||
| validator.check_positive_float(max_lr, 'max_lr') | |||
| validator.check_float_legal_value('max_lr', max_lr, None) | |||
| validator.check_positive_int(total_step, 'total_step') | |||
| validator.check_positive_int(step_per_epoch, 'step_per_epoch') | |||
| @@ -299,12 +299,12 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e | |||
| >>> polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) | |||
| [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] | |||
| """ | |||
| validator.check_float_positive('learning_rate', learning_rate, None) | |||
| validator.check_positive_float(learning_rate, 'learning_rate') | |||
| validator.check_float_legal_value('learning_rate', learning_rate, None) | |||
| if not isinstance(end_learning_rate, float): | |||
| raise TypeError("end_learning_rate must be float.") | |||
| validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) | |||
| validator.check_float_positive('power', power, None) | |||
| validator.check_positive_float(power, 'power') | |||
| validator.check_float_legal_value('power', power, None) | |||
| validator.check_positive_int(total_step, 'total_step') | |||
| validator.check_positive_int(step_per_epoch, 'step_per_epoch') | |||
| @@ -221,7 +221,7 @@ class SSIM(Cell): | |||
| validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) | |||
| self.max_val = max_val | |||
| self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) | |||
| self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) | |||
| self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) | |||
| self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) | |||
| self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) | |||
| window = _create_window(filter_size, filter_sigma) | |||
| @@ -299,7 +299,7 @@ class MSSSIM(Cell): | |||
| self.max_val = max_val | |||
| validator.check_value_type('power_factors', power_factors, [tuple, list], self.cls_name) | |||
| self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) | |||
| self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) | |||
| self.filter_sigma = validator.check_positive_float(filter_sigma, 'filter_sigma', self.cls_name) | |||
| self.k1 = validator.check_value_type('k1', k1, [float], self.cls_name) | |||
| self.k2 = validator.check_value_type('k2', k2, [float], self.cls_name) | |||
| window = _create_window(filter_size, filter_sigma) | |||
| @@ -45,9 +45,9 @@ class LearningRateSchedule(Cell): | |||
| def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name): | |||
| validator.check_positive_int(decay_steps, 'decay_steps', cls_name) | |||
| validator.check_float_positive('learning_rate', learning_rate, cls_name) | |||
| validator.check_positive_float(learning_rate, 'learning_rate', cls_name) | |||
| validator.check_float_legal_value('learning_rate', learning_rate, cls_name) | |||
| validator.check_float_positive('decay_rate', decay_rate, cls_name) | |||
| validator.check_positive_float(decay_rate, 'decay_rate', cls_name) | |||
| validator.check_float_legal_value('decay_rate', decay_rate, cls_name) | |||
| validator.check_value_type('is_stair', is_stair, [bool], cls_name) | |||
| @@ -255,7 +255,7 @@ class CosineDecayLR(LearningRateSchedule): | |||
| if not isinstance(min_lr, float): | |||
| raise TypeError("min_lr must be float.") | |||
| validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_float_positive('max_lr', max_lr, self.cls_name) | |||
| validator.check_positive_float(max_lr, 'max_lr', self.cls_name) | |||
| validator.check_float_legal_value('max_lr', max_lr, self.cls_name) | |||
| validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) | |||
| if min_lr >= max_lr: | |||
| @@ -318,7 +318,7 @@ class PolynomialDecayLR(LearningRateSchedule): | |||
| """ | |||
| def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False): | |||
| super(PolynomialDecayLR, self).__init__() | |||
| validator.check_float_positive('learning_rate', learning_rate, None) | |||
| validator.check_positive_float(learning_rate, 'learning_rate') | |||
| validator.check_float_legal_value('learning_rate', learning_rate, None) | |||
| if not isinstance(end_learning_rate, float): | |||
| raise TypeError("end_learning_rate must be float.") | |||
| @@ -326,7 +326,7 @@ class PolynomialDecayLR(LearningRateSchedule): | |||
| self.cls_name) | |||
| validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) | |||
| validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) | |||
| validator.check_float_positive('power', power, self.cls_name) | |||
| validator.check_positive_float(power, 'power', self.cls_name) | |||
| validator.check_float_legal_value('power', power, self.cls_name) | |||
| self.decay_steps = decay_steps | |||
| @@ -503,7 +503,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||
| def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| """Initialize batch norm fold layer""" | |||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | |||
| self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | |||
| @@ -546,7 +546,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer): | |||
| """Initialize BatchNormGrad layer""" | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | |||
| self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | |||
| self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'], | |||
| outputs=['dx']) | |||
| @@ -814,7 +814,7 @@ class BatchNormFoldD(PrimitiveWithInfer): | |||
| """Initialize _BatchNormFold layer""" | |||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold | |||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | |||
| self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | |||
| self.data_format = "NCHW" | |||
| @@ -842,7 +842,7 @@ class BatchNormFoldGradD(PrimitiveWithInfer): | |||
| def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| """Initialize _BatchNormFoldGrad layer""" | |||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad | |||
| self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) | |||
| self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | |||
| self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'], | |||
| @@ -3560,7 +3560,7 @@ class IFMR(PrimitiveWithInfer): | |||
| validator.check_value_type("max_percentile", max_percentile, [float], self.name) | |||
| validator.check_value_type("search_range", search_range, [list, tuple], self.name) | |||
| for item in search_range: | |||
| validator.check_float_positive("item of search_range", item, self.name) | |||
| validator.check_positive_float(item, "item of search_range", self.name) | |||
| validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name) | |||
| validator.check_value_type("search_step", search_step, [float], self.name) | |||
| validator.check_value_type("offset_flag", with_offset, [bool], self.name) | |||