Merge pull request !7339 from qujianwei/mastertags/v1.1.0
| @@ -94,10 +94,10 @@ rel_strs = { | |||||
| def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): | def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): | ||||
| """ | """ | ||||
| Check argument integer. | |||||
| Check argument integer. | |||||
| Usage: | |||||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||||
| Usage: | |||||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||||
| """ | """ | ||||
| rel_fn = Rel.get_fns(rel) | rel_fn = Rel.get_fns(rel) | ||||
| type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) | type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) | ||||
| @@ -122,13 +122,33 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): | |||||
| """ | """ | ||||
| prim_name = f'in \'{prim_name}\'' if prim_name else '' | prim_name = f'in \'{prim_name}\'' if prim_name else '' | ||||
| arg_name = f'\'{prim_name}\'' if arg_name else 'Input value' | arg_name = f'\'{prim_name}\'' if arg_name else 'Input value' | ||||
| if isinstance(arg_value, arg_type): | |||||
| if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool): | |||||
| if math.isinf(arg_value) or math.isnan(arg_value): | if math.isinf(arg_value) or math.isnan(arg_value): | ||||
| raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') | raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') | ||||
| return arg_value | return arg_value | ||||
| raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') | raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') | ||||
| def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None): | |||||
| """ | |||||
| Method for checking whether an int value is in some range. | |||||
| Usage: | |||||
| - number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0] | |||||
| - number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1] | |||||
| """ | |||||
| prim_name = f'in `{prim_name}`' if prim_name else '' | |||||
| arg_name = f'`{arg_name}`' if arg_name else '' | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) | |||||
| excp_cls = TypeError if type_mismatch else ValueError | |||||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise excp_cls("{} {} should be in range of {}, but got {:.3f} with type {}.".format( | |||||
| arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__)) | |||||
| return arg_value | |||||
| class Validator: | class Validator: | ||||
| """validator for checking input parameters""" | """validator for checking input parameters""" | ||||
| @@ -147,16 +167,13 @@ class Validator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_integer(arg_name, arg_value, value, rel, prim_name=None): | def check_integer(arg_name, arg_value, value, rel, prim_name=None): | ||||
| """Check argument is integer""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||||
| excp_cls = TypeError if type_mismatch else ValueError | |||||
| if type_mismatch or not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||||
| raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' | |||||
| f' with type `{type(arg_value).__name__}`.') | |||||
| return arg_value | |||||
| """ | |||||
| Checks input integer value `arg_value` compare to `value`. | |||||
| Usage: | |||||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||||
| """ | |||||
| return check_number(arg_value, value, rel, int, arg_name, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_is_int(arg_value, arg_name=None, prim_name=None): | def check_is_int(arg_value, arg_name=None, prim_name=None): | ||||
| @@ -168,7 +185,7 @@ class Validator: | |||||
| - number = check_is_int(number, int, "bias") | - number = check_is_int(number, int, "bias") | ||||
| - number = check_is_int(number, int, "bias", "bias_class") | - number = check_is_int(number, int, "bias", "bias_class") | ||||
| """ | """ | ||||
| check_is_number(arg_value, int, arg_name, prim_name) | |||||
| return check_is_number(arg_value, int, arg_name, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_positive_int(arg_value, arg_name=None, prim_name=None): | def check_positive_int(arg_value, arg_name=None, prim_name=None): | ||||
| @@ -214,6 +231,16 @@ class Validator: | |||||
| """ | """ | ||||
| return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) | return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) | ||||
| @staticmethod | |||||
| def check_float(arg_value, value, rel, arg_name=None, prim_name=None): | |||||
| """ | |||||
| Checks input float value `arg_value` compare to `value`. | |||||
| Usage: | |||||
| - number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0 | |||||
| """ | |||||
| return check_number(arg_value, value, rel, float, arg_name, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_is_float(arg_value, arg_name=None, prim_name=None): | def check_is_float(arg_value, arg_name=None, prim_name=None): | ||||
| """ | """ | ||||
| @@ -224,7 +251,7 @@ class Validator: | |||||
| - number = check_is_float(number, int, "bias") | - number = check_is_float(number, int, "bias") | ||||
| - number = check_is_float(number, int, "bias", "bias_class") | - number = check_is_float(number, int, "bias", "bias_class") | ||||
| """ | """ | ||||
| check_is_number(arg_value, float, arg_name, prim_name) | |||||
| return check_is_number(arg_value, float, arg_name, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_positive_float(arg_value, arg_name=None, prim_name=None): | def check_positive_float(arg_value, arg_name=None, prim_name=None): | ||||
| @@ -302,25 +329,26 @@ class Validator: | |||||
| return arg_value | return arg_value | ||||
| @staticmethod | @staticmethod | ||||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||||
| """Method for checking whether an int value is in some range.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||||
| excp_cls = TypeError if type_mismatch else ValueError | |||||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' | |||||
| f' but got `{arg_value}` with type `{type(arg_value).__name__}`.') | |||||
| return arg_value | |||||
| def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None): | |||||
| """ | |||||
| Method for checking whether input value is in int range. | |||||
| Usage: | |||||
| - number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1] | |||||
| - number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1] | |||||
| """ | |||||
| return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||||
| """Method for checking whether a numeric value is in some range.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| if not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None): | |||||
| """ | |||||
| Method for checking whether input value is in float range. | |||||
| Usage: | |||||
| - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0] | |||||
| - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0] | |||||
| """ | |||||
| return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_string(arg_value, valid_values, arg_name=None, prim_name=None): | def check_string(arg_value, valid_values, arg_name=None, prim_name=None): | ||||
| @@ -502,13 +530,6 @@ class Validator: | |||||
| f'{tuple(exp_shape)}, but got {shape}.') | f'{tuple(exp_shape)}, but got {shape}.') | ||||
| def check_int(input_param): | |||||
| """Int type judgment.""" | |||||
| if isinstance(input_param, int) and not isinstance(input_param, bool): | |||||
| return input_param | |||||
| raise TypeError("Input type must be int!") | |||||
| def check_int_zero_one(input_param): | def check_int_zero_one(input_param): | ||||
| """Judge whether it is 0 or 1.""" | """Judge whether it is 0 or 1.""" | ||||
| if input_param in (0, 1): | if input_param in (0, 1): | ||||
| @@ -233,7 +233,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): | |||||
| """ | """ | ||||
| if not isinstance(min_lr, float): | if not isinstance(min_lr, float): | ||||
| raise TypeError("min_lr must be float.") | raise TypeError("min_lr must be float.") | ||||
| validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) | |||||
| validator.check_non_negative_float(min_lr, "min_lr", None) | |||||
| validator.check_positive_float(max_lr, 'max_lr') | validator.check_positive_float(max_lr, 'max_lr') | ||||
| validator.check_is_float(max_lr, 'max_lr') | validator.check_is_float(max_lr, 'max_lr') | ||||
| validator.check_positive_int(total_step, 'total_step') | validator.check_positive_int(total_step, 'total_step') | ||||
| @@ -303,7 +303,7 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e | |||||
| validator.check_is_float(learning_rate, 'learning_rate') | validator.check_is_float(learning_rate, 'learning_rate') | ||||
| if not isinstance(end_learning_rate, float): | if not isinstance(end_learning_rate, float): | ||||
| raise TypeError("end_learning_rate must be float.") | raise TypeError("end_learning_rate must be float.") | ||||
| validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) | |||||
| validator.check_non_negative_float(end_learning_rate, "end_learning_rate", None) | |||||
| validator.check_positive_float(power, 'power') | validator.check_positive_float(power, 'power') | ||||
| validator.check_is_float(power, 'power') | validator.check_is_float(power, 'power') | ||||
| validator.check_positive_int(total_step, 'total_step') | validator.check_positive_int(total_step, 'total_step') | ||||
| @@ -356,7 +356,7 @@ def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch): | |||||
| """ | """ | ||||
| if not isinstance(learning_rate, float): | if not isinstance(learning_rate, float): | ||||
| raise TypeError("learning_rate must be float.") | raise TypeError("learning_rate must be float.") | ||||
| validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) | |||||
| validator.check_non_negative_float(learning_rate, "learning_rate", None) | |||||
| validator.check_positive_int(warmup_epoch, 'warmup_epoch') | validator.check_positive_int(warmup_epoch, 'warmup_epoch') | ||||
| validator.check_positive_int(total_step, 'total_step') | validator.check_positive_int(total_step, 'total_step') | ||||
| validator.check_positive_int(step_per_epoch, 'step_per_epoch') | validator.check_positive_int(step_per_epoch, 'step_per_epoch') | ||||
| @@ -451,8 +451,7 @@ class CentralCrop(Cell): | |||||
| def __init__(self, central_fraction): | def __init__(self, central_fraction): | ||||
| super(CentralCrop, self).__init__() | super(CentralCrop, self).__init__() | ||||
| validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) | validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) | ||||
| self.central_fraction = validator.check_number_range('central_fraction', central_fraction, | |||||
| 0.0, 1.0, Rel.INC_RIGHT, self.cls_name) | |||||
| self.central_fraction = validator.check_float_range(0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', central_fraction, self.cls_name) | |||||
| self.slice = P.Slice() | self.slice = P.Slice() | ||||
| def construct(self, image): | def construct(self, image): | ||||
| @@ -254,7 +254,7 @@ class CosineDecayLR(LearningRateSchedule): | |||||
| super(CosineDecayLR, self).__init__() | super(CosineDecayLR, self).__init__() | ||||
| if not isinstance(min_lr, float): | if not isinstance(min_lr, float): | ||||
| raise TypeError("min_lr must be float.") | raise TypeError("min_lr must be float.") | ||||
| validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||||
| validator.check_non_negative_float(min_lr, "min_lr", self.cls_name) | |||||
| validator.check_positive_float(max_lr, 'max_lr', self.cls_name) | validator.check_positive_float(max_lr, 'max_lr', self.cls_name) | ||||
| validator.check_is_float(max_lr, 'max_lr', self.cls_name) | validator.check_is_float(max_lr, 'max_lr', self.cls_name) | ||||
| validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) | validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) | ||||
| @@ -322,8 +322,7 @@ class PolynomialDecayLR(LearningRateSchedule): | |||||
| validator.check_is_float(learning_rate, 'learning_rate') | validator.check_is_float(learning_rate, 'learning_rate') | ||||
| if not isinstance(end_learning_rate, float): | if not isinstance(end_learning_rate, float): | ||||
| raise TypeError("end_learning_rate must be float.") | raise TypeError("end_learning_rate must be float.") | ||||
| validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, | |||||
| self.cls_name) | |||||
| validator.check_non_negative_float(end_learning_rate, "end_learning_rate", self.cls_name) | |||||
| validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) | validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) | ||||
| validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) | validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) | ||||
| validator.check_positive_float(power, 'power', self.cls_name) | validator.check_positive_float(power, 'power', self.cls_name) | ||||
| @@ -387,7 +386,7 @@ class WarmUpLR(LearningRateSchedule): | |||||
| super(WarmUpLR, self).__init__() | super(WarmUpLR, self).__init__() | ||||
| if not isinstance(learning_rate, float): | if not isinstance(learning_rate, float): | ||||
| raise TypeError("learning_rate must be float.") | raise TypeError("learning_rate must be float.") | ||||
| validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||||
| validator.check_non_negative_float(learning_rate, "learning_rate", self.cls_name) | |||||
| validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name) | validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name) | ||||
| self.warmup_steps = warmup_steps | self.warmup_steps = warmup_steps | ||||
| self.learning_rate = learning_rate | self.learning_rate = learning_rate | ||||
| @@ -368,7 +368,7 @@ class CosineEmbeddingLoss(_Loss): | |||||
| self.reduce_sum = P.ReduceSum() | self.reduce_sum = P.ReduceSum() | ||||
| self.maximum = P.Maximum() | self.maximum = P.Maximum() | ||||
| validator.check_value_type("margin", margin, [float], self.cls_name) | validator.check_value_type("margin", margin, [float], self.cls_name) | ||||
| self.margin = validator.check_number_range("margin", margin, -1.0, 1.0, Rel.INC_BOTH, self.cls_name) | |||||
| self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name) | |||||
| def construct(self, x1, x2, y): | def construct(self, x1, x2, y): | ||||
| F.same_type_shape(x1, x2) | F.same_type_shape(x1, x2) | ||||
| @@ -126,9 +126,9 @@ def _check_param_value(beta1, beta2, eps, prim_name): | |||||
| validator.check_value_type("beta1", beta1, [float], prim_name) | validator.check_value_type("beta1", beta1, [float], prim_name) | ||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | validator.check_value_type("beta2", beta2, [float], prim_name) | ||||
| validator.check_value_type("eps", eps, [float], prim_name) | validator.check_value_type("eps", eps, [float], prim_name) | ||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||||
| validator.check_positive_float(eps, "eps", prim_name) | |||||
| class Adam(Optimizer): | class Adam(Optimizer): | ||||
| @@ -177,9 +177,9 @@ def _check_param_value(beta1, beta2, eps, prim_name): | |||||
| validator.check_value_type("beta1", beta1, [float], prim_name) | validator.check_value_type("beta1", beta1, [float], prim_name) | ||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | validator.check_value_type("beta2", beta2, [float], prim_name) | ||||
| validator.check_value_type("eps", eps, [float], prim_name) | validator.check_value_type("eps", eps, [float], prim_name) | ||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||||
| validator.check_positive_float(eps, "eps", prim_name) | |||||
| class Lamb(Optimizer): | class Lamb(Optimizer): | ||||
| @@ -70,10 +70,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | validator.check_value_type("beta2", beta2, [float], prim_name) | ||||
| validator.check_value_type("eps", eps, [float], prim_name) | validator.check_value_type("eps", eps, [float], prim_name) | ||||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | ||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||||
| validator.check_positive_float(eps, "eps", prim_name) | |||||
| validator.check_non_negative_float(weight_decay, "weight_decay", prim_name) | |||||
| class LazyAdam(Optimizer): | class LazyAdam(Optimizer): | ||||
| @@ -100,7 +100,7 @@ class Optimizer(Cell): | |||||
| if isinstance(loss_scale, int): | if isinstance(loss_scale, int): | ||||
| loss_scale = float(loss_scale) | loss_scale = float(loss_scale) | ||||
| validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) | validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) | ||||
| validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) | |||||
| validator.check_positive_float(loss_scale, "loss_scale", self.cls_name) | |||||
| self.loss_scale = loss_scale | self.loss_scale = loss_scale | ||||
| weight_decay = self._preprocess_weight_decay(weight_decay) | weight_decay = self._preprocess_weight_decay(weight_decay) | ||||
| @@ -221,7 +221,7 @@ class Optimizer(Cell): | |||||
| """Check weight decay, and convert int to float.""" | """Check weight decay, and convert int to float.""" | ||||
| if isinstance(weight_decay, (float, int)): | if isinstance(weight_decay, (float, int)): | ||||
| weight_decay = float(weight_decay) | weight_decay = float(weight_decay) | ||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||||
| validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name) | |||||
| return weight_decay | return weight_decay | ||||
| raise TypeError("Weight decay should be int or float.") | raise TypeError("Weight decay should be int or float.") | ||||
| @@ -229,7 +229,7 @@ class Optimizer(Cell): | |||||
| """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule.""" | """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule.""" | ||||
| if isinstance(learning_rate, (float, int)): | if isinstance(learning_rate, (float, int)): | ||||
| learning_rate = float(learning_rate) | learning_rate = float(learning_rate) | ||||
| validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||||
| validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name) | |||||
| return learning_rate | return learning_rate | ||||
| if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: | if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: | ||||
| return learning_rate | return learning_rate | ||||
| @@ -45,9 +45,9 @@ def _check_param_value(accum, l1, l2, use_locking, prim_name=None): | |||||
| validator.check_value_type("l1", l1, [float], prim_name) | validator.check_value_type("l1", l1, [float], prim_name) | ||||
| validator.check_value_type("l2", l2, [float], prim_name) | validator.check_value_type("l2", l2, [float], prim_name) | ||||
| validator.check_value_type("use_locking", use_locking, [bool], prim_name) | validator.check_value_type("use_locking", use_locking, [bool], prim_name) | ||||
| validator.check_number_range("accum", accum, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| validator.check_non_negative_float(accum, "accum", prim_name) | |||||
| validator.check_non_negative_float(l1, "l1", prim_name) | |||||
| validator.check_non_negative_float(l2, "l2", prim_name) | |||||
| class ProximalAdagrad(Optimizer): | class ProximalAdagrad(Optimizer): | ||||
| @@ -154,11 +154,11 @@ class RMSProp(Optimizer): | |||||
| use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0): | use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0): | ||||
| super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale) | super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale) | ||||
| validator.check_value_type("decay", decay, [float], self.cls_name) | validator.check_value_type("decay", decay, [float], self.cls_name) | ||||
| validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||||
| validator.check_non_negative_float(decay, "decay", self.cls_name) | |||||
| validator.check_value_type("momentum", momentum, [float], self.cls_name) | validator.check_value_type("momentum", momentum, [float], self.cls_name) | ||||
| validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||||
| validator.check_non_negative_float(momentum, "momentum", self.cls_name) | |||||
| validator.check_value_type("epsilon", epsilon, [float], self.cls_name) | validator.check_value_type("epsilon", epsilon, [float], self.cls_name) | ||||
| validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) | |||||
| validator.check_positive_float(epsilon, "epsilon", self.cls_name) | |||||
| validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) | validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) | ||||
| validator.check_value_type("centered", centered, [bool], self.cls_name) | validator.check_value_type("centered", centered, [bool], self.cls_name) | ||||
| @@ -69,7 +69,7 @@ def get_concat_offset(x_shp, x_type, axis, prim_name): | |||||
| validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) | validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) | ||||
| validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name) | validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name) | ||||
| rank_base = len(x_shp[0]) | rank_base = len(x_shp[0]) | ||||
| validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) | |||||
| validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name) | |||||
| if axis < 0: | if axis < 0: | ||||
| axis = axis + rank_base | axis = axis + rank_base | ||||
| all_shp = x_shp[0][axis] | all_shp = x_shp[0][axis] | ||||
| @@ -188,7 +188,7 @@ class BatchNormGrad(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, is_training=False, epsilon=1e-5): | def __init__(self, is_training=False, epsilon=1e-5): | ||||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | ||||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | |||||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): | def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): | ||||
| @@ -485,7 +485,7 @@ class DropoutGrad(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, keep_prob=0.5): | def __init__(self, keep_prob=0.5): | ||||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) | |||||
| self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) | |||||
| def infer_shape(self, dy_shape, mask_shape): | def infer_shape(self, dy_shape, mask_shape): | ||||
| return dy_shape | return dy_shape | ||||
| @@ -902,7 +902,7 @@ class LogSoftmaxGrad(PrimitiveWithInfer): | |||||
| def infer_shape(self, dout, logits): | def infer_shape(self, dout, logits): | ||||
| rank = len(logits) | rank = len(logits) | ||||
| validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name) | |||||
| validator.check_int_range(self.axis, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name) | |||||
| return logits | return logits | ||||
| def infer_dtype(self, dout, logits): | def infer_dtype(self, dout, logits): | ||||
| @@ -921,7 +921,7 @@ class LSTMGradData(PrimitiveWithInfer): | |||||
| self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | ||||
| self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | ||||
| self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | ||||
| self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) | |||||
| if bidirectional: | if bidirectional: | ||||
| self.num_directions = 2 | self.num_directions = 2 | ||||
| @@ -970,7 +970,7 @@ class LSTMGradWeight(PrimitiveWithInfer): | |||||
| self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | ||||
| self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | ||||
| self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | ||||
| self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) | |||||
| if bidirectional: | if bidirectional: | ||||
| self.num_directions = 2 | self.num_directions = 2 | ||||
| @@ -1005,7 +1005,7 @@ class LSTMGrad(PrimitiveWithInfer): | |||||
| self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | ||||
| self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | ||||
| self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | ||||
| self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) | |||||
| if bidirectional: | if bidirectional: | ||||
| self.num_directions = 2 | self.num_directions = 2 | ||||
| @@ -1652,7 +1652,7 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, keep_prob): | def __init__(self, keep_prob): | ||||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | ||||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name) | |||||
| self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name) | |||||
| self.add_prim_attr("io_format", "ND") | self.add_prim_attr("io_format", "ND") | ||||
| def infer_shape(self, dgate_shape, w_shape): | def infer_shape(self, dgate_shape, w_shape): | ||||
| @@ -76,8 +76,7 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer): | |||||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | ||||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | ||||
| self.ema_decay = validator.check_number_range( | |||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name) | |||||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | self.init_prim_io_names(inputs=['x', 'min', 'max'], | ||||
| outputs=['min_up', 'max_up']) | outputs=['min_up', 'max_up']) | ||||
| @@ -136,10 +135,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | ||||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | ||||
| self.ema_decay = validator.check_number_range( | |||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name) | |||||
| if self.is_ascend: | if self.is_ascend: | ||||
| self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name) | |||||
| else: | else: | ||||
| self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) | self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) | ||||
| self.init_prim_io_names( | self.init_prim_io_names( | ||||
| @@ -222,10 +220,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||||
| 'symmetric', symmetric, (bool,), self.name) | 'symmetric', symmetric, (bool,), self.name) | ||||
| self.narrow_range = validator.check_value_type( | self.narrow_range = validator.check_value_type( | ||||
| 'narrow_range', narrow_range, (bool,), self.name) | 'narrow_range', narrow_range, (bool,), self.name) | ||||
| self.training = validator.check_value_type( | |||||
| 'training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_number_range( | |||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.training = validator.check_value_type('training', training, (bool,), self.name) | |||||
| self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name) | |||||
| self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | ||||
| self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) | self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) | ||||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | self.init_prim_io_names(inputs=['x', 'min', 'max'], | ||||
| @@ -366,12 +362,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer): | |||||
| 'narrow_range', narrow_range, (bool,), self.name) | 'narrow_range', narrow_range, (bool,), self.name) | ||||
| self.training = validator.check_value_type( | self.training = validator.check_value_type( | ||||
| 'training', training, (bool,), self.name) | 'training', training, (bool,), self.name) | ||||
| self.ema_decay = validator.check_number_range( | |||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name) | |||||
| self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | ||||
| self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) | self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) | ||||
| if self.is_ascend: | if self.is_ascend: | ||||
| self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name) | |||||
| else: | else: | ||||
| self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) | self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) | ||||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) | self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) | ||||
| @@ -495,7 +490,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | ||||
| """Initialize batch norm fold layer""" | """Initialize batch norm fold layer""" | ||||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||||
| self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | ||||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | ||||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | ||||
| @@ -806,7 +801,7 @@ class BatchNormFoldD(PrimitiveWithInfer): | |||||
| def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | ||||
| """Initialize _BatchNormFold layer""" | """Initialize _BatchNormFold layer""" | ||||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold | from mindspore.ops._op_impl._custom_op import batchnorm_fold | ||||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||||
| self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | ||||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | ||||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | ||||
| @@ -129,7 +129,7 @@ class ExpandDims(PrimitiveWithInfer): | |||||
| x_shape = list(x['shape']) | x_shape = list(x['shape']) | ||||
| axis_v = axis['value'] | axis_v = axis['value'] | ||||
| rank = len(x_shape) | rank = len(x_shape) | ||||
| validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name) | |||||
| validator.check_int_range(axis_v, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name) | |||||
| value = None | value = None | ||||
| if x['value'] is not None: | if x['value'] is not None: | ||||
| value = x['value'].asnumpy() | value = x['value'].asnumpy() | ||||
| @@ -534,7 +534,7 @@ class Squeeze(PrimitiveWithInfer): | |||||
| ret = [d for d in x_shape if d != 1] | ret = [d for d in x_shape if d != 1] | ||||
| else: | else: | ||||
| for a in axis: | for a in axis: | ||||
| validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name) | |||||
| validator.check_int_range(a, -ndim, ndim - 1, Rel.INC_BOTH, 'axis or its elements', self.name) | |||||
| if x_shape[a] != 1: | if x_shape[a] != 1: | ||||
| raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.') | raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.') | ||||
| ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)] | ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)] | ||||
| @@ -658,7 +658,7 @@ class GatherV2(PrimitiveWithCheck): | |||||
| axis_v = axis['value'] | axis_v = axis['value'] | ||||
| params_shp = params['shape'] | params_shp = params['shape'] | ||||
| rank = len(params_shp) | rank = len(params_shp) | ||||
| validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||||
| if axis_v < 0: | if axis_v < 0: | ||||
| axis_v += rank | axis_v += rank | ||||
| @@ -777,7 +777,7 @@ class Split(PrimitiveWithInfer): | |||||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | ||||
| x_shape = list(x['shape']) | x_shape = list(x['shape']) | ||||
| dim = len(x_shape) | dim = len(x_shape) | ||||
| validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) | |||||
| validator.check_positive_int(self.output_num, "output_num", self.name) | validator.check_positive_int(self.output_num, "output_num", self.name) | ||||
| output_valid_check = x_shape[self.axis] % self.output_num | output_valid_check = x_shape[self.axis] % self.output_num | ||||
| if output_valid_check != 0: | if output_valid_check != 0: | ||||
| @@ -1224,7 +1224,7 @@ class Argmax(PrimitiveWithInfer): | |||||
| if axis is None: | if axis is None: | ||||
| axis = 0 | axis = 0 | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name) | |||||
| axis = axis + x_rank if axis < 0 else axis | axis = axis + x_rank if axis < 0 else axis | ||||
| ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] | ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] | ||||
| return ouput_shape | return ouput_shape | ||||
| @@ -1272,7 +1272,7 @@ class Argmin(PrimitiveWithInfer): | |||||
| if axis is None: | if axis is None: | ||||
| axis = 0 | axis = 0 | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name) | |||||
| axis = axis + x_rank if axis < 0 else axis | axis = axis + x_rank if axis < 0 else axis | ||||
| ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] | ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] | ||||
| return ouput_shape | return ouput_shape | ||||
| @@ -1325,7 +1325,7 @@ class ArgMaxWithValue(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| axis = self.axis | axis = self.axis | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name) | |||||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) | ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) | ||||
| return ouput_shape, ouput_shape | return ouput_shape, ouput_shape | ||||
| @@ -1377,7 +1377,7 @@ class ArgMinWithValue(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| axis = self.axis | axis = self.axis | ||||
| x_rank = len(x_shape) | x_rank = len(x_shape) | ||||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name) | |||||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) | ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) | ||||
| return ouput_shape, ouput_shape | return ouput_shape, ouput_shape | ||||
| @@ -1760,7 +1760,7 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name): | |||||
| rank_base = len(x_shape[0]) | rank_base = len(x_shape[0]) | ||||
| N = len(x_shape) | N = len(x_shape) | ||||
| out_shape = x_shape[0] | out_shape = x_shape[0] | ||||
| validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) | |||||
| validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name) | |||||
| if axis < 0: | if axis < 0: | ||||
| axis = axis + rank_base + 1 | axis = axis + rank_base + 1 | ||||
| for i in range(1, N): | for i in range(1, N): | ||||
| @@ -1863,7 +1863,7 @@ class Unpack(PrimitiveWithInfer): | |||||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | ||||
| x_shape = list(x['shape']) | x_shape = list(x['shape']) | ||||
| dim = len(x_shape) | dim = len(x_shape) | ||||
| validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) | |||||
| if self.axis < 0: | if self.axis < 0: | ||||
| self.axis = self.axis + dim | self.axis = self.axis + dim | ||||
| output_num = x_shape[self.axis] | output_num = x_shape[self.axis] | ||||
| @@ -1965,7 +1965,7 @@ class ReverseV2(PrimitiveWithInfer): | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| dim = len(x_shape) | dim = len(x_shape) | ||||
| for i, each in enumerate(self.axis): | for i, each in enumerate(self.axis): | ||||
| validator.check_int_range(f'axis[{i}]', each, -dim, dim, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name) | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| @@ -206,7 +206,7 @@ class _HostAllGather(PrimitiveWithInfer): | |||||
| validator.check_value_type('group', group, (tuple, list), self.name) | validator.check_value_type('group', group, (tuple, list), self.name) | ||||
| validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | ||||
| for r in group: | for r in group: | ||||
| validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) | |||||
| validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name) | |||||
| validator.check_value_type("rank_id", r, (int,), self.name) | validator.check_value_type("rank_id", r, (int,), self.name) | ||||
| self.group_size = len(group) | self.group_size = len(group) | ||||
| self.add_prim_attr('group', group) | self.add_prim_attr('group', group) | ||||
| @@ -315,7 +315,7 @@ class _HostReduceScatter(PrimitiveWithInfer): | |||||
| validator.check_value_type('group', group, (tuple, list), self.name) | validator.check_value_type('group', group, (tuple, list), self.name) | ||||
| validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | ||||
| for r in group: | for r in group: | ||||
| validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) | |||||
| validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name) | |||||
| validator.check_value_type("rank_id", r, (int,), self.name) | validator.check_value_type("rank_id", r, (int,), self.name) | ||||
| self.op = op | self.op = op | ||||
| self.group_size = len(group) | self.group_size = len(group) | ||||
| @@ -70,8 +70,7 @@ class ControlDepend(Primitive): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, depend_mode=0): | def __init__(self, depend_mode=0): | ||||
| """init""" | """init""" | ||||
| validator.check_int_range( | |||||
| "depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name) | |||||
| validator.check_int_range(depend_mode, 0, 1, Rel.INC_BOTH, "depend_mode", self.name) | |||||
| def __call__(self, src, dst): | def __call__(self, src, dst): | ||||
| return src | return src | ||||
| @@ -31,7 +31,7 @@ def _infer_shape_reduce(x, axis, keep_dims, prim_name): | |||||
| """Common infer for reduce operator""" | """Common infer for reduce operator""" | ||||
| def reduce_one_axis(one_axis): | def reduce_one_axis(one_axis): | ||||
| validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name) | |||||
| validator.check_int_range(one_axis, -dim, dim, Rel.INC_LEFT, 'axis', prim_name) | |||||
| if one_axis < 0: | if one_axis < 0: | ||||
| one_axis += dim | one_axis += dim | ||||
| axis_reduce.add(one_axis) | axis_reduce.add(one_axis) | ||||
| @@ -149,7 +149,7 @@ class Softmax(PrimitiveWithInfer): | |||||
| validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) | validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) | ||||
| rank = len(logits) | rank = len(logits) | ||||
| for axis_v in self.axis: | for axis_v in self.axis: | ||||
| validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||||
| return logits | return logits | ||||
| def infer_dtype(self, logits): | def infer_dtype(self, logits): | ||||
| @@ -193,7 +193,7 @@ class LogSoftmax(PrimitiveWithInfer): | |||||
| def infer_shape(self, logits): | def infer_shape(self, logits): | ||||
| rank = len(logits) | rank = len(logits) | ||||
| validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis', self.name) | |||||
| return logits | return logits | ||||
| def infer_dtype(self, logits): | def infer_dtype(self, logits): | ||||
| @@ -637,8 +637,8 @@ class FusedBatchNorm(Primitive): | |||||
| self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | ||||
| outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) | outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) | ||||
| self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) | self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) | ||||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | |||||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||||
| self._update_parameter = True | self._update_parameter = True | ||||
| @@ -710,8 +710,8 @@ class FusedBatchNormEx(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | ||||
| outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) | outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) | ||||
| self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) | self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) | ||||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | |||||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||||
| self._update_parameter = True | self._update_parameter = True | ||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| @@ -818,8 +818,8 @@ class BNTrainingUpdate(PrimitiveWithInfer): | |||||
| validator.check_value_type("isRef", isRef, [bool], self.name) | validator.check_value_type("isRef", isRef, [bool], self.name) | ||||
| validator.check_value_type("epsilon", epsilon, [float], self.name) | validator.check_value_type("epsilon", epsilon, [float], self.name) | ||||
| validator.check_value_type("factor", factor, [float], self.name) | validator.check_value_type("factor", factor, [float], self.name) | ||||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate') | |||||
| self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate') | |||||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', 'BNTrainingUpdate') | |||||
| self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate') | |||||
| def infer_shape(self, x, sum, square_sum, scale, b, mean, variance): | def infer_shape(self, x, sum, square_sum, scale, b, mean, variance): | ||||
| validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name) | validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name) | ||||
| @@ -898,7 +898,7 @@ class BatchNorm(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, is_training=False, epsilon=1e-5): | def __init__(self, is_training=False, epsilon=1e-5): | ||||
| validator.check_value_type('is_training', is_training, (bool,), self.name) | validator.check_value_type('is_training', is_training, (bool,), self.name) | ||||
| validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | |||||
| validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], | self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], | ||||
| outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) | outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) | ||||
| @@ -2383,7 +2383,7 @@ class L2Normalize(PrimitiveWithInfer): | |||||
| def infer_shape(self, input_x): | def infer_shape(self, input_x): | ||||
| dim = len(input_x) | dim = len(input_x) | ||||
| validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) | |||||
| return input_x | return input_x | ||||
| def infer_dtype(self, input_x): | def infer_dtype(self, input_x): | ||||
| @@ -2481,10 +2481,10 @@ class DropoutDoMask(PrimitiveWithInfer): | |||||
| keep_prob_v = keep_prob['value'] | keep_prob_v = keep_prob['value'] | ||||
| if keep_prob_v is not None: | if keep_prob_v is not None: | ||||
| if isinstance(keep_prob['dtype'], type(mstype.tensor)): | if isinstance(keep_prob['dtype'], type(mstype.tensor)): | ||||
| validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name) | |||||
| validator.check_float_range(keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, 'keep_prob', self.name) | |||||
| else: | else: | ||||
| validator.check_value_type("keep_prob", keep_prob_v, [float], self.name) | validator.check_value_type("keep_prob", keep_prob_v, [float], self.name) | ||||
| validator.check_number_range('keep_prob', keep_prob_v, 0, 1, Rel.INC_BOTH, self.name) | |||||
| validator.check_float_range(keep_prob_v, 0, 1, Rel.INC_BOTH, 'keep_prob', self.name) | |||||
| out = {'shape': input_x_shape, | out = {'shape': input_x_shape, | ||||
| 'dtype': input_x['dtype'], | 'dtype': input_x['dtype'], | ||||
| @@ -2584,7 +2584,7 @@ class OneHot(PrimitiveWithInfer): | |||||
| # check shape | # check shape | ||||
| indices_shp = indices['shape'] | indices_shp = indices['shape'] | ||||
| validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name) | |||||
| validator.check_int_range(self.axis, -1, len(indices_shp), Rel.INC_BOTH, "axis", self.name) | |||||
| depth_val = depth['value'] | depth_val = depth['value'] | ||||
| validator.check_non_negative_int(depth_val, "depth", self.name) | validator.check_non_negative_int(depth_val, "depth", self.name) | ||||
| # create new dimension at end if self.axis is -1 | # create new dimension at end if self.axis is -1 | ||||
| @@ -2771,7 +2771,7 @@ class LSTM(PrimitiveWithInfer): | |||||
| self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name) | self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name) | ||||
| self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) | self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) | ||||
| self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | ||||
| self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) | |||||
| self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) | |||||
| if bidirectional: | if bidirectional: | ||||
| self.num_directions = 2 | self.num_directions = 2 | ||||
| @@ -3054,7 +3054,7 @@ class ROIAlign(PrimitiveWithInfer): | |||||
| validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) | validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) | ||||
| validator.check_value_type("sample_num", sample_num, [int], self.name) | validator.check_value_type("sample_num", sample_num, [int], self.name) | ||||
| validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name) | validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name) | ||||
| validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name) | |||||
| validator.check_int_range(roi_end_mode, 0, 1, Rel.INC_BOTH, "roi_end_mode", self.name) | |||||
| self.pooled_height = pooled_height | self.pooled_height = pooled_height | ||||
| self.pooled_width = pooled_width | self.pooled_width = pooled_width | ||||
| self.spatial_scale = spatial_scale | self.spatial_scale = spatial_scale | ||||
| @@ -3502,9 +3502,9 @@ class FusedSparseFtrl(PrimitiveWithInfer): | |||||
| validator.check_value_type("l1", l1, [float], self.name) | validator.check_value_type("l1", l1, [float], self.name) | ||||
| validator.check_value_type("l2", l2, [float], self.name) | validator.check_value_type("l2", l2, [float], self.name) | ||||
| validator.check_value_type("lr_power", lr_power, [float], self.name) | validator.check_value_type("lr_power", lr_power, [float], self.name) | ||||
| self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) | |||||
| self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||||
| self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||||
| self.lr = validator.check_positive_float(lr, "lr", self.name) | |||||
| self.l1 = validator.check_non_negative_float(l1, "l1", self.name) | |||||
| self.l2 = validator.check_non_negative_float(l2, "l2", self.name) | |||||
| self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | ||||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | ||||
| @@ -4240,7 +4240,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, lr, update_slots=True, use_locking=False): | def __init__(self, lr, update_slots=True, use_locking=False): | ||||
| validator.check_value_type("lr", lr, [float], self.name) | validator.check_value_type("lr", lr, [float], self.name) | ||||
| validator.check_number_range("lr", lr, float("-inf"), float("inf"), Rel.INC_NEITHER, self.name) | |||||
| validator.check_is_float(lr, "lr", self.name) | |||||
| validator.check_value_type("update_slots", update_slots, [bool], self.name) | validator.check_value_type("update_slots", update_slots, [bool], self.name) | ||||
| validator.check_value_type("use_locking", use_locking, [bool], self.name) | validator.check_value_type("use_locking", use_locking, [bool], self.name) | ||||
| @@ -5142,9 +5142,9 @@ class SparseApplyFtrl(PrimitiveWithCheck): | |||||
| validator.check_value_type("l1", l1, [float], self.name) | validator.check_value_type("l1", l1, [float], self.name) | ||||
| validator.check_value_type("l2", l2, [float], self.name) | validator.check_value_type("l2", l2, [float], self.name) | ||||
| validator.check_value_type("lr_power", lr_power, [float], self.name) | validator.check_value_type("lr_power", lr_power, [float], self.name) | ||||
| self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) | |||||
| self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||||
| self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||||
| self.lr = validator.check_positive_float(lr, "lr", self.name) | |||||
| self.l1 = validator.check_non_negative_float(l1, "l1", self.name) | |||||
| self.l2 = validator.check_non_negative_float(l2, "l2", self.name) | |||||
| self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | ||||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | ||||
| self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], | self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], | ||||
| @@ -5239,9 +5239,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): | |||||
| validator.check_value_type("l1", l1, [float], self.name) | validator.check_value_type("l1", l1, [float], self.name) | ||||
| validator.check_value_type("l2", l2, [float], self.name) | validator.check_value_type("l2", l2, [float], self.name) | ||||
| validator.check_value_type("lr_power", lr_power, [float], self.name) | validator.check_value_type("lr_power", lr_power, [float], self.name) | ||||
| self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) | |||||
| self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||||
| self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||||
| self.lr = validator.check_positive_float(lr, "lr", self.name) | |||||
| self.l1 = validator.check_non_negative_float(l1, "l1", self.name) | |||||
| self.l2 = validator.check_non_negative_float(l2, "l2", self.name) | |||||
| self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | ||||
| self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name) | self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name) | ||||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | ||||
| @@ -5285,7 +5285,7 @@ class Dropout(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, keep_prob=0.5): | def __init__(self, keep_prob=0.5): | ||||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) | |||||
| self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) | |||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) | validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) | ||||
| @@ -5510,7 +5510,7 @@ class BasicLSTMCell(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'): | def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'): | ||||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | ||||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name) | |||||
| self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name) | |||||
| self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | ||||
| self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name) | self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name) | ||||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | ||||
| @@ -100,8 +100,8 @@ def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): | |||||
| lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) | lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) | ||||
| lr = float(lr_init) + lr_inc * (i + 1) | lr = float(lr_init) + lr_inc * (i + 1) | ||||
| else: | else: | ||||
| cosine_decay = 0.5 * (1 + math.cos(math.pi * (i-warmup_steps) / decay_steps)) | |||||
| lr = (lr_max-lr_end)*cosine_decay + lr_end | |||||
| cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps)) | |||||
| lr = (lr_max - lr_end) * cosine_decay + lr_end | |||||
| lr_each_step.append(lr) | lr_each_step.append(lr) | ||||
| return lr_each_step | return lr_each_step | ||||
| @@ -122,7 +122,7 @@ class MySparseGatherV2(PrimitiveWithInfer): | |||||
| axis_v = axis['value'] | axis_v = axis['value'] | ||||
| params_shp = params['shape'] | params_shp = params['shape'] | ||||
| rank = len(params_shp) | rank = len(params_shp) | ||||
| validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) | |||||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||||
| if axis_v < 0: | if axis_v < 0: | ||||
| axis_v += rank | axis_v += rank | ||||
| out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] | out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] | ||||
| @@ -208,10 +208,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | validator.check_value_type("beta2", beta2, [float], prim_name) | ||||
| validator.check_value_type("eps", eps, [float], prim_name) | validator.check_value_type("eps", eps, [float], prim_name) | ||||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | ||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||||
| validator.check_positive_float(eps, "eps", prim_name) | |||||
| validator.check_non_negative_float(weight_decay, "weight_decay", prim_name) | |||||
| class AdamWeightDecaySparse(Optimizer): | class AdamWeightDecaySparse(Optimizer): | ||||
| @@ -14,55 +14,97 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test checkparameter """ | """ test checkparameter """ | ||||
| import pytest | import pytest | ||||
| from mindspore._checkparam import check_int, check_input_format, Validator, twice | |||||
| import numpy as np | |||||
| from mindspore._checkparam import check_input_format, Validator, twice, Rel | |||||
| kernel_size = 5 | kernel_size = 5 | ||||
| kernel_size1 = twice(kernel_size) | kernel_size1 = twice(kernel_size) | ||||
| assert kernel_size1 == (5, 5) | assert kernel_size1 == (5, 5) | ||||
| def test_check_integer1(): | |||||
| with pytest.raises(TypeError): | |||||
| Validator.check_integer("input", 0, Rel.GE, "number") | |||||
| def test_check_int_1(): | |||||
| assert check_int(3) == 3 | |||||
| def test_check_integer2(): | |||||
| with pytest.raises(ValueError): | |||||
| Validator.check_integer(-1, 0, Rel.GE, "number") | |||||
| def test_check_integer3(): | |||||
| input = np.random.randint(0, 100) | |||||
| assert Validator.check_integer(input, 0, Rel.GE, "number") == input | |||||
| def check_int_positive_1(): | |||||
| with pytest.raises(ValueError): | |||||
| Validator.check_positive_int(-1) | |||||
| def test_check_int1(): | |||||
| input = np.random.randint(-100, 100) | |||||
| assert Validator.check_is_int(input) == input | |||||
| def test_check_int2(): | |||||
| with pytest.raises(TypeError): | |||||
| Validator.check_is_int(3.3) | |||||
| def test_NCHW1(): | |||||
| assert check_input_format("NCHW") == "NCHW" | |||||
| def test_check_int3(): | |||||
| with pytest.raises(TypeError): | |||||
| Validator.check_is_int("str") | |||||
| def test_check_int4(): | |||||
| with pytest.raises(TypeError): | |||||
| Validator.check_is_int(True) | |||||
| def test_check_is_int5(): | |||||
| with pytest.raises(TypeError): | |||||
| Validator.check_is_int(True) | |||||
| with pytest.raises(TypeError): | |||||
| Validator.check_is_int(False) | |||||
| def test_check_positive_int1(): | |||||
| input = np.random.randint(0, 100) | |||||
| assert Validator.check_positive_int(input) == input | |||||
| def test_NCHW3(): | |||||
| def test_check_positive_int2(): | |||||
| input = np.random.randint(-100, 0) | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| check_input_format("rt") | |||||
| Validator.check_positive_int(input) | |||||
| def test_check_positive_int3(): | |||||
| with pytest.raises(ValueError): | |||||
| Validator.check_positive_int(3.3) | |||||
| def test_check_int_2(): | |||||
| def test_check_positive_int4(): | |||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| check_int(3.3) | |||||
| Validator.check_positive_int("str") | |||||
| def test_check_negative_int1(): | |||||
| input = np.random.randint(-100, -1) | |||||
| assert Validator.check_negative_int(input) == input | |||||
| def test_check_int_3(): | |||||
| with pytest.raises(TypeError): | |||||
| check_int("str") | |||||
| def test_check_negative_int2(): | |||||
| input = np.random.randint(0, 100) | |||||
| with pytest.raises(ValueError): | |||||
| Validator.check_negative_int(input) | |||||
| def test_check_negative_int3(): | |||||
| with pytest.raises(ValueError): | |||||
| Validator.check_negative_int(3.3) | |||||
| def test_check_int_4(): | |||||
| def test_check_negative_int4(): | |||||
| with pytest.raises(TypeError): | with pytest.raises(TypeError): | ||||
| check_int(True) | |||||
| Validator.check_negative_int("str") | |||||
| def test_check_non_positive_int1(): | |||||
| input = np.random.randint(-100, 0) | |||||
| assert Validator.check_non_positive_int(input) == input | |||||
| def test_check_int_5(): | |||||
| check_int(0) | |||||
| check_int(1) | |||||
| with pytest.raises(TypeError): | |||||
| check_int(True) | |||||
| with pytest.raises(TypeError): | |||||
| check_int(False) | |||||
| def test_check_non_positive_int2(): | |||||
| input = np.random.randint(1, 100) | |||||
| with pytest.raises(ValueError): | |||||
| Validator.check_non_positive_int(input) | |||||
| def test_check_non_positive_int3(): | |||||
| with pytest.raises(ValueError): | |||||
| Validator.check_non_positive_int(3.3) | |||||
| def test_check_non_positive_int4(): | |||||
| with pytest.raises(TypeError): | |||||
| Validator.check_non_positive_int("str") | |||||
| def test_check_bool_1(): | def test_check_bool_1(): | ||||
| assert Validator.check_bool(True) | assert Validator.check_bool(True) | ||||