Merge pull request !7339 from qujianwei/mastertags/v1.1.0
| @@ -94,10 +94,10 @@ rel_strs = { | |||
| def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None): | |||
| """ | |||
| Check argument integer. | |||
| Check argument integer. | |||
| Usage: | |||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| Usage: | |||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| """ | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, arg_type) or isinstance(arg_value, bool) | |||
| @@ -122,13 +122,33 @@ def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): | |||
| """ | |||
| prim_name = f'in \'{prim_name}\'' if prim_name else '' | |||
| arg_name = f'\'{prim_name}\'' if arg_name else 'Input value' | |||
| if isinstance(arg_value, arg_type): | |||
| if isinstance(arg_value, arg_type) and not isinstance(arg_value, bool): | |||
| if math.isinf(arg_value) or math.isnan(arg_value): | |||
| raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') | |||
| return arg_value | |||
| raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') | |||
| def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None): | |||
| """ | |||
| Method for checking whether an int value is in some range. | |||
| Usage: | |||
| - number = check_number_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number", float) # number in [0.0, 1.0] | |||
| - number = check_number_range(number, 0, 1, Rel.INC_NEITHER, "number", int) # number in [0, 1] | |||
| """ | |||
| prim_name = f'in `{prim_name}`' if prim_name else '' | |||
| arg_name = f'`{arg_name}`' if arg_name else '' | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool) | |||
| excp_cls = TypeError if type_mismatch else ValueError | |||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise excp_cls("{} {} should be in range of {}, but got {:.3f} with type {}.".format( | |||
| arg_name, prim_name, rel_str, arg_value, type(arg_value).__name__)) | |||
| return arg_value | |||
| class Validator: | |||
| """validator for checking input parameters""" | |||
| @@ -147,16 +167,13 @@ class Validator: | |||
| @staticmethod | |||
| def check_integer(arg_name, arg_value, value, rel, prim_name=None): | |||
| """Check argument is integer""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||
| excp_cls = TypeError if type_mismatch else ValueError | |||
| if type_mismatch or not rel_fn(arg_value, value): | |||
| rel_str = Rel.get_strs(rel).format(value) | |||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||
| raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' | |||
| f' with type `{type(arg_value).__name__}`.') | |||
| return arg_value | |||
| """ | |||
| Checks input integer value `arg_value` compare to `value`. | |||
| Usage: | |||
| - number = check_integer(number, 0, Rel.GE, "number", None) # number >= 0 | |||
| """ | |||
| return check_number(arg_value, value, rel, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_is_int(arg_value, arg_name=None, prim_name=None): | |||
| @@ -168,7 +185,7 @@ class Validator: | |||
| - number = check_is_int(number, int, "bias") | |||
| - number = check_is_int(number, int, "bias", "bias_class") | |||
| """ | |||
| check_is_number(arg_value, int, arg_name, prim_name) | |||
| return check_is_number(arg_value, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_positive_int(arg_value, arg_name=None, prim_name=None): | |||
| @@ -214,6 +231,16 @@ class Validator: | |||
| """ | |||
| return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_float(arg_value, value, rel, arg_name=None, prim_name=None): | |||
| """ | |||
| Checks input float value `arg_value` compare to `value`. | |||
| Usage: | |||
| - number = check_float(number, 0.0, Rel.GE, "number", None) # number >= 0 | |||
| """ | |||
| return check_number(arg_value, value, rel, float, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_is_float(arg_value, arg_name=None, prim_name=None): | |||
| """ | |||
| @@ -224,7 +251,7 @@ class Validator: | |||
| - number = check_is_float(number, int, "bias") | |||
| - number = check_is_float(number, int, "bias", "bias_class") | |||
| """ | |||
| check_is_number(arg_value, float, arg_name, prim_name) | |||
| return check_is_number(arg_value, float, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_positive_float(arg_value, arg_name=None, prim_name=None): | |||
| @@ -302,25 +329,26 @@ class Validator: | |||
| return arg_value | |||
| @staticmethod | |||
| def check_int_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||
| """Method for checking whether an int value is in some range.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | |||
| excp_cls = TypeError if type_mismatch else ValueError | |||
| if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' | |||
| f' but got `{arg_value}` with type `{type(arg_value).__name__}`.') | |||
| return arg_value | |||
| def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None): | |||
| """ | |||
| Method for checking whether input value is in int range. | |||
| Usage: | |||
| - number = check_int_range(number, 0, 1, Rel.INC_NEITHER) # number in [0, 1] | |||
| - number = check_int_range(number, 0, 1, Rel.INC_NEITHER, "number") # number in [0, 1] | |||
| """ | |||
| return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||
| """Method for checking whether a numeric value is in some range.""" | |||
| rel_fn = Rel.get_fns(rel) | |||
| if not rel_fn(arg_value, lower_limit, upper_limit): | |||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.') | |||
| return arg_value | |||
| def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None): | |||
| """ | |||
| Method for checking whether input value is in float range. | |||
| Usage: | |||
| - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER) # number in [0.0, 1.0] | |||
| - number = check_float_range(number, 0.0, 1.0, Rel.INC_NEITHER, "number") # number in [0.0, 1.0] | |||
| """ | |||
| return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name) | |||
| @staticmethod | |||
| def check_string(arg_value, valid_values, arg_name=None, prim_name=None): | |||
| @@ -502,13 +530,6 @@ class Validator: | |||
| f'{tuple(exp_shape)}, but got {shape}.') | |||
| def check_int(input_param): | |||
| """Int type judgment.""" | |||
| if isinstance(input_param, int) and not isinstance(input_param, bool): | |||
| return input_param | |||
| raise TypeError("Input type must be int!") | |||
| def check_int_zero_one(input_param): | |||
| """Judge whether it is 0 or 1.""" | |||
| if input_param in (0, 1): | |||
| @@ -233,7 +233,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): | |||
| """ | |||
| if not isinstance(min_lr, float): | |||
| raise TypeError("min_lr must be float.") | |||
| validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) | |||
| validator.check_non_negative_float(min_lr, "min_lr", None) | |||
| validator.check_positive_float(max_lr, 'max_lr') | |||
| validator.check_is_float(max_lr, 'max_lr') | |||
| validator.check_positive_int(total_step, 'total_step') | |||
| @@ -303,7 +303,7 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e | |||
| validator.check_is_float(learning_rate, 'learning_rate') | |||
| if not isinstance(end_learning_rate, float): | |||
| raise TypeError("end_learning_rate must be float.") | |||
| validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) | |||
| validator.check_non_negative_float(end_learning_rate, "end_learning_rate", None) | |||
| validator.check_positive_float(power, 'power') | |||
| validator.check_is_float(power, 'power') | |||
| validator.check_positive_int(total_step, 'total_step') | |||
| @@ -356,7 +356,7 @@ def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch): | |||
| """ | |||
| if not isinstance(learning_rate, float): | |||
| raise TypeError("learning_rate must be float.") | |||
| validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) | |||
| validator.check_non_negative_float(learning_rate, "learning_rate", None) | |||
| validator.check_positive_int(warmup_epoch, 'warmup_epoch') | |||
| validator.check_positive_int(total_step, 'total_step') | |||
| validator.check_positive_int(step_per_epoch, 'step_per_epoch') | |||
| @@ -451,8 +451,7 @@ class CentralCrop(Cell): | |||
| def __init__(self, central_fraction): | |||
| super(CentralCrop, self).__init__() | |||
| validator.check_value_type("central_fraction", central_fraction, [float], self.cls_name) | |||
| self.central_fraction = validator.check_number_range('central_fraction', central_fraction, | |||
| 0.0, 1.0, Rel.INC_RIGHT, self.cls_name) | |||
| self.central_fraction = validator.check_float_range(0.0, 1.0, Rel.INC_RIGHT, 'central_fraction', central_fraction, self.cls_name) | |||
| self.slice = P.Slice() | |||
| def construct(self, image): | |||
| @@ -254,7 +254,7 @@ class CosineDecayLR(LearningRateSchedule): | |||
| super(CosineDecayLR, self).__init__() | |||
| if not isinstance(min_lr, float): | |||
| raise TypeError("min_lr must be float.") | |||
| validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_non_negative_float(min_lr, "min_lr", self.cls_name) | |||
| validator.check_positive_float(max_lr, 'max_lr', self.cls_name) | |||
| validator.check_is_float(max_lr, 'max_lr', self.cls_name) | |||
| validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) | |||
| @@ -322,8 +322,7 @@ class PolynomialDecayLR(LearningRateSchedule): | |||
| validator.check_is_float(learning_rate, 'learning_rate') | |||
| if not isinstance(end_learning_rate, float): | |||
| raise TypeError("end_learning_rate must be float.") | |||
| validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, | |||
| self.cls_name) | |||
| validator.check_non_negative_float(end_learning_rate, "end_learning_rate", self.cls_name) | |||
| validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) | |||
| validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) | |||
| validator.check_positive_float(power, 'power', self.cls_name) | |||
| @@ -387,7 +386,7 @@ class WarmUpLR(LearningRateSchedule): | |||
| super(WarmUpLR, self).__init__() | |||
| if not isinstance(learning_rate, float): | |||
| raise TypeError("learning_rate must be float.") | |||
| validator.check_number_range("learning_rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_non_negative_float(learning_rate, "learning_rate", self.cls_name) | |||
| validator.check_positive_int(warmup_steps, 'warmup_steps', self.cls_name) | |||
| self.warmup_steps = warmup_steps | |||
| self.learning_rate = learning_rate | |||
| @@ -368,7 +368,7 @@ class CosineEmbeddingLoss(_Loss): | |||
| self.reduce_sum = P.ReduceSum() | |||
| self.maximum = P.Maximum() | |||
| validator.check_value_type("margin", margin, [float], self.cls_name) | |||
| self.margin = validator.check_number_range("margin", margin, -1.0, 1.0, Rel.INC_BOTH, self.cls_name) | |||
| self.margin = validator.check_float_range(margin, -1.0, 1.0, Rel.INC_BOTH, "margin", self.cls_name) | |||
| def construct(self, x1, x2, y): | |||
| F.same_type_shape(x1, x2) | |||
| @@ -126,9 +126,9 @@ def _check_param_value(beta1, beta2, eps, prim_name): | |||
| validator.check_value_type("beta1", beta1, [float], prim_name) | |||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||
| validator.check_value_type("eps", eps, [float], prim_name) | |||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||
| validator.check_positive_float(eps, "eps", prim_name) | |||
| class Adam(Optimizer): | |||
| @@ -177,9 +177,9 @@ def _check_param_value(beta1, beta2, eps, prim_name): | |||
| validator.check_value_type("beta1", beta1, [float], prim_name) | |||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||
| validator.check_value_type("eps", eps, [float], prim_name) | |||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||
| validator.check_positive_float(eps, "eps", prim_name) | |||
| class Lamb(Optimizer): | |||
| @@ -70,10 +70,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||
| validator.check_value_type("eps", eps, [float], prim_name) | |||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | |||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||
| validator.check_positive_float(eps, "eps", prim_name) | |||
| validator.check_non_negative_float(weight_decay, "weight_decay", prim_name) | |||
| class LazyAdam(Optimizer): | |||
| @@ -100,7 +100,7 @@ class Optimizer(Cell): | |||
| if isinstance(loss_scale, int): | |||
| loss_scale = float(loss_scale) | |||
| validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) | |||
| validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) | |||
| validator.check_positive_float(loss_scale, "loss_scale", self.cls_name) | |||
| self.loss_scale = loss_scale | |||
| weight_decay = self._preprocess_weight_decay(weight_decay) | |||
| @@ -221,7 +221,7 @@ class Optimizer(Cell): | |||
| """Check weight decay, and convert int to float.""" | |||
| if isinstance(weight_decay, (float, int)): | |||
| weight_decay = float(weight_decay) | |||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name) | |||
| return weight_decay | |||
| raise TypeError("Weight decay should be int or float.") | |||
| @@ -229,7 +229,7 @@ class Optimizer(Cell): | |||
| """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule.""" | |||
| if isinstance(learning_rate, (float, int)): | |||
| learning_rate = float(learning_rate) | |||
| validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name) | |||
| return learning_rate | |||
| if isinstance(learning_rate, Tensor) and learning_rate.dim() == 0: | |||
| return learning_rate | |||
| @@ -45,9 +45,9 @@ def _check_param_value(accum, l1, l2, use_locking, prim_name=None): | |||
| validator.check_value_type("l1", l1, [float], prim_name) | |||
| validator.check_value_type("l2", l2, [float], prim_name) | |||
| validator.check_value_type("use_locking", use_locking, [bool], prim_name) | |||
| validator.check_number_range("accum", accum, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||
| validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||
| validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||
| validator.check_non_negative_float(accum, "accum", prim_name) | |||
| validator.check_non_negative_float(l1, "l1", prim_name) | |||
| validator.check_non_negative_float(l2, "l2", prim_name) | |||
| class ProximalAdagrad(Optimizer): | |||
| @@ -154,11 +154,11 @@ class RMSProp(Optimizer): | |||
| use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0): | |||
| super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale) | |||
| validator.check_value_type("decay", decay, [float], self.cls_name) | |||
| validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_non_negative_float(decay, "decay", self.cls_name) | |||
| validator.check_value_type("momentum", momentum, [float], self.cls_name) | |||
| validator.check_number_range("momentum", momentum, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_non_negative_float(momentum, "momentum", self.cls_name) | |||
| validator.check_value_type("epsilon", epsilon, [float], self.cls_name) | |||
| validator.check_number_range("epsilon", epsilon, 0.0, float("inf"), Rel.INC_NEITHER, self.cls_name) | |||
| validator.check_positive_float(epsilon, "epsilon", self.cls_name) | |||
| validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) | |||
| validator.check_value_type("centered", centered, [bool], self.cls_name) | |||
| @@ -69,7 +69,7 @@ def get_concat_offset(x_shp, x_type, axis, prim_name): | |||
| validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) | |||
| validator.check_positive_int(len(x_shp[0]), "len of x_shp[0]", prim_name) | |||
| rank_base = len(x_shp[0]) | |||
| validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) | |||
| validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name) | |||
| if axis < 0: | |||
| axis = axis + rank_base | |||
| all_shp = x_shp[0][axis] | |||
| @@ -188,7 +188,7 @@ class BatchNormGrad(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, is_training=False, epsilon=1e-5): | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): | |||
| @@ -485,7 +485,7 @@ class DropoutGrad(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, keep_prob=0.5): | |||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) | |||
| self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) | |||
| def infer_shape(self, dy_shape, mask_shape): | |||
| return dy_shape | |||
| @@ -902,7 +902,7 @@ class LogSoftmaxGrad(PrimitiveWithInfer): | |||
| def infer_shape(self, dout, logits): | |||
| rank = len(logits) | |||
| validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name) | |||
| validator.check_int_range(self.axis, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name) | |||
| return logits | |||
| def infer_dtype(self, dout, logits): | |||
| @@ -921,7 +921,7 @@ class LSTMGradData(PrimitiveWithInfer): | |||
| self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | |||
| self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | |||
| self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | |||
| self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) | |||
| if bidirectional: | |||
| self.num_directions = 2 | |||
| @@ -970,7 +970,7 @@ class LSTMGradWeight(PrimitiveWithInfer): | |||
| self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | |||
| self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | |||
| self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | |||
| self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) | |||
| if bidirectional: | |||
| self.num_directions = 2 | |||
| @@ -1005,7 +1005,7 @@ class LSTMGrad(PrimitiveWithInfer): | |||
| self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) | |||
| self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) | |||
| self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | |||
| self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) | |||
| if bidirectional: | |||
| self.num_directions = 2 | |||
| @@ -1652,7 +1652,7 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, keep_prob): | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name) | |||
| self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, dgate_shape, w_shape): | |||
| @@ -76,8 +76,7 @@ class MinMaxUpdatePerLayer(PrimitiveWithInfer): | |||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name) | |||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | |||
| outputs=['min_up', 'max_up']) | |||
| @@ -136,10 +135,9 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||
| f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") | |||
| self.ema = validator.check_value_type('ema', ema, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name) | |||
| if self.is_ascend: | |||
| self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name) | |||
| else: | |||
| self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) | |||
| self.init_prim_io_names( | |||
| @@ -222,10 +220,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||
| 'symmetric', symmetric, (bool,), self.name) | |||
| self.narrow_range = validator.check_value_type( | |||
| 'narrow_range', narrow_range, (bool,), self.name) | |||
| self.training = validator.check_value_type( | |||
| 'training', training, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.training = validator.check_value_type('training', training, (bool,), self.name) | |||
| self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name) | |||
| self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | |||
| self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) | |||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | |||
| @@ -366,12 +362,11 @@ class FakeQuantPerChannel(PrimitiveWithInfer): | |||
| 'narrow_range', narrow_range, (bool,), self.name) | |||
| self.training = validator.check_value_type( | |||
| 'training', training, (bool,), self.name) | |||
| self.ema_decay = validator.check_number_range( | |||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name) | |||
| self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | |||
| self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) | |||
| if self.is_ascend: | |||
| self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name) | |||
| else: | |||
| self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) | |||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) | |||
| @@ -495,7 +490,7 @@ class BatchNormFold(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| """Initialize batch norm fold layer""" | |||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | |||
| @@ -806,7 +801,7 @@ class BatchNormFoldD(PrimitiveWithInfer): | |||
| def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0): | |||
| """Initialize _BatchNormFold layer""" | |||
| from mindspore.ops._op_impl._custom_op import batchnorm_fold | |||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name) | |||
| self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) | |||
| @@ -129,7 +129,7 @@ class ExpandDims(PrimitiveWithInfer): | |||
| x_shape = list(x['shape']) | |||
| axis_v = axis['value'] | |||
| rank = len(x_shape) | |||
| validator.check_int_range('axis', axis_v, -rank - 1, rank, Rel.INC_BOTH, self.name) | |||
| validator.check_int_range(axis_v, -rank - 1, rank, Rel.INC_BOTH, 'axis', self.name) | |||
| value = None | |||
| if x['value'] is not None: | |||
| value = x['value'].asnumpy() | |||
| @@ -534,7 +534,7 @@ class Squeeze(PrimitiveWithInfer): | |||
| ret = [d for d in x_shape if d != 1] | |||
| else: | |||
| for a in axis: | |||
| validator.check_int_range('axis or its elements', a, -ndim, ndim - 1, Rel.INC_BOTH, self.name) | |||
| validator.check_int_range(a, -ndim, ndim - 1, Rel.INC_BOTH, 'axis or its elements', self.name) | |||
| if x_shape[a] != 1: | |||
| raise ValueError('Cannot select an axis to squeeze out which has size not equal to one.') | |||
| ret = [x_shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)] | |||
| @@ -658,7 +658,7 @@ class GatherV2(PrimitiveWithCheck): | |||
| axis_v = axis['value'] | |||
| params_shp = params['shape'] | |||
| rank = len(params_shp) | |||
| validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||
| if axis_v < 0: | |||
| axis_v += rank | |||
| @@ -777,7 +777,7 @@ class Split(PrimitiveWithInfer): | |||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||
| x_shape = list(x['shape']) | |||
| dim = len(x_shape) | |||
| validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) | |||
| validator.check_positive_int(self.output_num, "output_num", self.name) | |||
| output_valid_check = x_shape[self.axis] % self.output_num | |||
| if output_valid_check != 0: | |||
| @@ -1224,7 +1224,7 @@ class Argmax(PrimitiveWithInfer): | |||
| if axis is None: | |||
| axis = 0 | |||
| x_rank = len(x_shape) | |||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name) | |||
| axis = axis + x_rank if axis < 0 else axis | |||
| ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] | |||
| return ouput_shape | |||
| @@ -1272,7 +1272,7 @@ class Argmin(PrimitiveWithInfer): | |||
| if axis is None: | |||
| axis = 0 | |||
| x_rank = len(x_shape) | |||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name) | |||
| axis = axis + x_rank if axis < 0 else axis | |||
| ouput_shape = [x_shape[i] for i in range(x_rank) if i != axis] | |||
| return ouput_shape | |||
| @@ -1325,7 +1325,7 @@ class ArgMaxWithValue(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| axis = self.axis | |||
| x_rank = len(x_shape) | |||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name) | |||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) | |||
| return ouput_shape, ouput_shape | |||
| @@ -1377,7 +1377,7 @@ class ArgMinWithValue(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| axis = self.axis | |||
| x_rank = len(x_shape) | |||
| validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(axis, -x_rank, x_rank, Rel.INC_LEFT, "axis", self.name) | |||
| ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name) | |||
| return ouput_shape, ouput_shape | |||
| @@ -1760,7 +1760,7 @@ def _get_pack_shape(x_shape, x_type, axis, prim_name): | |||
| rank_base = len(x_shape[0]) | |||
| N = len(x_shape) | |||
| out_shape = x_shape[0] | |||
| validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) | |||
| validator.check_int_range(axis, -rank_base - 1, rank_base, Rel.INC_BOTH, 'axis', prim_name) | |||
| if axis < 0: | |||
| axis = axis + rank_base + 1 | |||
| for i in range(1, N): | |||
| @@ -1863,7 +1863,7 @@ class Unpack(PrimitiveWithInfer): | |||
| validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) | |||
| x_shape = list(x['shape']) | |||
| dim = len(x_shape) | |||
| validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) | |||
| if self.axis < 0: | |||
| self.axis = self.axis + dim | |||
| output_num = x_shape[self.axis] | |||
| @@ -1965,7 +1965,7 @@ class ReverseV2(PrimitiveWithInfer): | |||
| def infer_shape(self, x_shape): | |||
| dim = len(x_shape) | |||
| for i, each in enumerate(self.axis): | |||
| validator.check_int_range(f'axis[{i}]', each, -dim, dim, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(each, -dim, dim, Rel.INC_LEFT, f'axis[{i}]', self.name) | |||
| return x_shape | |||
| def infer_dtype(self, x_dtype): | |||
| @@ -206,7 +206,7 @@ class _HostAllGather(PrimitiveWithInfer): | |||
| validator.check_value_type('group', group, (tuple, list), self.name) | |||
| validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | |||
| for r in group: | |||
| validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) | |||
| validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name) | |||
| validator.check_value_type("rank_id", r, (int,), self.name) | |||
| self.group_size = len(group) | |||
| self.add_prim_attr('group', group) | |||
| @@ -315,7 +315,7 @@ class _HostReduceScatter(PrimitiveWithInfer): | |||
| validator.check_value_type('group', group, (tuple, list), self.name) | |||
| validator.check_integer("group size", len(group), 2, Rel.GE, self.name) | |||
| for r in group: | |||
| validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) | |||
| validator.check_int_range(r, 0, 7, Rel.INC_BOTH, "rank_id", self.name) | |||
| validator.check_value_type("rank_id", r, (int,), self.name) | |||
| self.op = op | |||
| self.group_size = len(group) | |||
| @@ -70,8 +70,7 @@ class ControlDepend(Primitive): | |||
| @prim_attr_register | |||
| def __init__(self, depend_mode=0): | |||
| """init""" | |||
| validator.check_int_range( | |||
| "depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name) | |||
| validator.check_int_range(depend_mode, 0, 1, Rel.INC_BOTH, "depend_mode", self.name) | |||
| def __call__(self, src, dst): | |||
| return src | |||
| @@ -31,7 +31,7 @@ def _infer_shape_reduce(x, axis, keep_dims, prim_name): | |||
| """Common infer for reduce operator""" | |||
| def reduce_one_axis(one_axis): | |||
| validator.check_int_range('axis', one_axis, -dim, dim, Rel.INC_LEFT, prim_name) | |||
| validator.check_int_range(one_axis, -dim, dim, Rel.INC_LEFT, 'axis', prim_name) | |||
| if one_axis < 0: | |||
| one_axis += dim | |||
| axis_reduce.add(one_axis) | |||
| @@ -149,7 +149,7 @@ class Softmax(PrimitiveWithInfer): | |||
| validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) | |||
| rank = len(logits) | |||
| for axis_v in self.axis: | |||
| validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||
| return logits | |||
| def infer_dtype(self, logits): | |||
| @@ -193,7 +193,7 @@ class LogSoftmax(PrimitiveWithInfer): | |||
| def infer_shape(self, logits): | |||
| rank = len(logits) | |||
| validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(self.axis, -rank, rank, Rel.INC_LEFT, 'axis', self.name) | |||
| return logits | |||
| def infer_dtype(self, logits): | |||
| @@ -637,8 +637,8 @@ class FusedBatchNorm(Primitive): | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | |||
| outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) | |||
| self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) | |||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | |||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self._update_parameter = True | |||
| @@ -710,8 +710,8 @@ class FusedBatchNormEx(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], | |||
| outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) | |||
| self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) | |||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | |||
| self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) | |||
| self._update_parameter = True | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| @@ -818,8 +818,8 @@ class BNTrainingUpdate(PrimitiveWithInfer): | |||
| validator.check_value_type("isRef", isRef, [bool], self.name) | |||
| validator.check_value_type("epsilon", epsilon, [float], self.name) | |||
| validator.check_value_type("factor", factor, [float], self.name) | |||
| self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate') | |||
| self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate') | |||
| self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', 'BNTrainingUpdate') | |||
| self.factor = validator.check_float_range(factor, 0, 1, Rel.INC_BOTH, 'factor', 'BNTrainingUpdate') | |||
| def infer_shape(self, x, sum, square_sum, scale, b, mean, variance): | |||
| validator.check_integer("x rank", len(x), 4, Rel.EQ, self.name) | |||
| @@ -898,7 +898,7 @@ class BatchNorm(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, is_training=False, epsilon=1e-5): | |||
| validator.check_value_type('is_training', is_training, (bool,), self.name) | |||
| validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) | |||
| validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) | |||
| self.add_prim_attr('data_format', "NCHW") | |||
| self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], | |||
| outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) | |||
| @@ -2383,7 +2383,7 @@ class L2Normalize(PrimitiveWithInfer): | |||
| def infer_shape(self, input_x): | |||
| dim = len(input_x) | |||
| validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(self.axis, -dim, dim, Rel.INC_LEFT, 'axis value', self.name) | |||
| return input_x | |||
| def infer_dtype(self, input_x): | |||
| @@ -2481,10 +2481,10 @@ class DropoutDoMask(PrimitiveWithInfer): | |||
| keep_prob_v = keep_prob['value'] | |||
| if keep_prob_v is not None: | |||
| if isinstance(keep_prob['dtype'], type(mstype.tensor)): | |||
| validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name) | |||
| validator.check_float_range(keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, 'keep_prob', self.name) | |||
| else: | |||
| validator.check_value_type("keep_prob", keep_prob_v, [float], self.name) | |||
| validator.check_number_range('keep_prob', keep_prob_v, 0, 1, Rel.INC_BOTH, self.name) | |||
| validator.check_float_range(keep_prob_v, 0, 1, Rel.INC_BOTH, 'keep_prob', self.name) | |||
| out = {'shape': input_x_shape, | |||
| 'dtype': input_x['dtype'], | |||
| @@ -2584,7 +2584,7 @@ class OneHot(PrimitiveWithInfer): | |||
| # check shape | |||
| indices_shp = indices['shape'] | |||
| validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name) | |||
| validator.check_int_range(self.axis, -1, len(indices_shp), Rel.INC_BOTH, "axis", self.name) | |||
| depth_val = depth['value'] | |||
| validator.check_non_negative_int(depth_val, "depth", self.name) | |||
| # create new dimension at end if self.axis is -1 | |||
| @@ -2771,7 +2771,7 @@ class LSTM(PrimitiveWithInfer): | |||
| self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name) | |||
| self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) | |||
| self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) | |||
| self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) | |||
| self.dropout = validator.check_float_range(dropout, 0, 1, Rel.INC_BOTH, 'dropout', self.name) | |||
| if bidirectional: | |||
| self.num_directions = 2 | |||
| @@ -3054,7 +3054,7 @@ class ROIAlign(PrimitiveWithInfer): | |||
| validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) | |||
| validator.check_value_type("sample_num", sample_num, [int], self.name) | |||
| validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name) | |||
| validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name) | |||
| validator.check_int_range(roi_end_mode, 0, 1, Rel.INC_BOTH, "roi_end_mode", self.name) | |||
| self.pooled_height = pooled_height | |||
| self.pooled_width = pooled_width | |||
| self.spatial_scale = spatial_scale | |||
| @@ -3502,9 +3502,9 @@ class FusedSparseFtrl(PrimitiveWithInfer): | |||
| validator.check_value_type("l1", l1, [float], self.name) | |||
| validator.check_value_type("l2", l2, [float], self.name) | |||
| validator.check_value_type("lr_power", lr_power, [float], self.name) | |||
| self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) | |||
| self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||
| self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||
| self.lr = validator.check_positive_float(lr, "lr", self.name) | |||
| self.l1 = validator.check_non_negative_float(l1, "l1", self.name) | |||
| self.l2 = validator.check_non_negative_float(l2, "l2", self.name) | |||
| self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | |||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| @@ -4240,7 +4240,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, lr, update_slots=True, use_locking=False): | |||
| validator.check_value_type("lr", lr, [float], self.name) | |||
| validator.check_number_range("lr", lr, float("-inf"), float("inf"), Rel.INC_NEITHER, self.name) | |||
| validator.check_is_float(lr, "lr", self.name) | |||
| validator.check_value_type("update_slots", update_slots, [bool], self.name) | |||
| validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| @@ -5142,9 +5142,9 @@ class SparseApplyFtrl(PrimitiveWithCheck): | |||
| validator.check_value_type("l1", l1, [float], self.name) | |||
| validator.check_value_type("l2", l2, [float], self.name) | |||
| validator.check_value_type("lr_power", lr_power, [float], self.name) | |||
| self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) | |||
| self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||
| self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||
| self.lr = validator.check_positive_float(lr, "lr", self.name) | |||
| self.l1 = validator.check_non_negative_float(l1, "l1", self.name) | |||
| self.l2 = validator.check_non_negative_float(l2, "l2", self.name) | |||
| self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | |||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], | |||
| @@ -5239,9 +5239,9 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): | |||
| validator.check_value_type("l1", l1, [float], self.name) | |||
| validator.check_value_type("l2", l2, [float], self.name) | |||
| validator.check_value_type("lr_power", lr_power, [float], self.name) | |||
| self.lr = validator.check_number_range("lr", lr, 0.0, float("inf"), Rel.INC_NEITHER, self.name) | |||
| self.l1 = validator.check_number_range("l1", l1, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||
| self.l2 = validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) | |||
| self.lr = validator.check_positive_float(lr, "lr", self.name) | |||
| self.l1 = validator.check_non_negative_float(l1, "l1", self.name) | |||
| self.l2 = validator.check_non_negative_float(l2, "l2", self.name) | |||
| self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) | |||
| self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name) | |||
| self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) | |||
| @@ -5285,7 +5285,7 @@ class Dropout(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, keep_prob=0.5): | |||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0, 1, Rel.INC_RIGHT, self.name) | |||
| self.keep_prob = validator.check_float_range(keep_prob, 0, 1, Rel.INC_RIGHT, "keep_prob", self.name) | |||
| def infer_shape(self, x_shape): | |||
| validator.check_integer("x_shape", len(x_shape), 1, Rel.GE, self.name) | |||
| @@ -5510,7 +5510,7 @@ class BasicLSTMCell(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self, keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'): | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name) | |||
| self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name) | |||
| self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | |||
| self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name) | |||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | |||
| @@ -100,8 +100,8 @@ def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps): | |||
| lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps) | |||
| lr = float(lr_init) + lr_inc * (i + 1) | |||
| else: | |||
| cosine_decay = 0.5 * (1 + math.cos(math.pi * (i-warmup_steps) / decay_steps)) | |||
| lr = (lr_max-lr_end)*cosine_decay + lr_end | |||
| cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps)) | |||
| lr = (lr_max - lr_end) * cosine_decay + lr_end | |||
| lr_each_step.append(lr) | |||
| return lr_each_step | |||
| @@ -122,7 +122,7 @@ class MySparseGatherV2(PrimitiveWithInfer): | |||
| axis_v = axis['value'] | |||
| params_shp = params['shape'] | |||
| rank = len(params_shp) | |||
| validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) | |||
| validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) | |||
| if axis_v < 0: | |||
| axis_v += rank | |||
| out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] | |||
| @@ -208,10 +208,10 @@ def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||
| validator.check_value_type("eps", eps, [float], prim_name) | |||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | |||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||
| validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name) | |||
| validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name) | |||
| validator.check_positive_float(eps, "eps", prim_name) | |||
| validator.check_non_negative_float(weight_decay, "weight_decay", prim_name) | |||
| class AdamWeightDecaySparse(Optimizer): | |||
| @@ -14,55 +14,97 @@ | |||
| # ============================================================================ | |||
| """ test checkparameter """ | |||
| import pytest | |||
| from mindspore._checkparam import check_int, check_input_format, Validator, twice | |||
| import numpy as np | |||
| from mindspore._checkparam import check_input_format, Validator, twice, Rel | |||
| kernel_size = 5 | |||
| kernel_size1 = twice(kernel_size) | |||
| assert kernel_size1 == (5, 5) | |||
| def test_check_integer1(): | |||
| with pytest.raises(TypeError): | |||
| Validator.check_integer("input", 0, Rel.GE, "number") | |||
| def test_check_int_1(): | |||
| assert check_int(3) == 3 | |||
| def test_check_integer2(): | |||
| with pytest.raises(ValueError): | |||
| Validator.check_integer(-1, 0, Rel.GE, "number") | |||
| def test_check_integer3(): | |||
| input = np.random.randint(0, 100) | |||
| assert Validator.check_integer(input, 0, Rel.GE, "number") == input | |||
| def check_int_positive_1(): | |||
| with pytest.raises(ValueError): | |||
| Validator.check_positive_int(-1) | |||
| def test_check_int1(): | |||
| input = np.random.randint(-100, 100) | |||
| assert Validator.check_is_int(input) == input | |||
| def test_check_int2(): | |||
| with pytest.raises(TypeError): | |||
| Validator.check_is_int(3.3) | |||
| def test_NCHW1(): | |||
| assert check_input_format("NCHW") == "NCHW" | |||
| def test_check_int3(): | |||
| with pytest.raises(TypeError): | |||
| Validator.check_is_int("str") | |||
| def test_check_int4(): | |||
| with pytest.raises(TypeError): | |||
| Validator.check_is_int(True) | |||
| def test_check_is_int5(): | |||
| with pytest.raises(TypeError): | |||
| Validator.check_is_int(True) | |||
| with pytest.raises(TypeError): | |||
| Validator.check_is_int(False) | |||
| def test_check_positive_int1(): | |||
| input = np.random.randint(0, 100) | |||
| assert Validator.check_positive_int(input) == input | |||
| def test_NCHW3(): | |||
| def test_check_positive_int2(): | |||
| input = np.random.randint(-100, 0) | |||
| with pytest.raises(ValueError): | |||
| check_input_format("rt") | |||
| Validator.check_positive_int(input) | |||
| def test_check_positive_int3(): | |||
| with pytest.raises(ValueError): | |||
| Validator.check_positive_int(3.3) | |||
| def test_check_int_2(): | |||
| def test_check_positive_int4(): | |||
| with pytest.raises(TypeError): | |||
| check_int(3.3) | |||
| Validator.check_positive_int("str") | |||
| def test_check_negative_int1(): | |||
| input = np.random.randint(-100, -1) | |||
| assert Validator.check_negative_int(input) == input | |||
| def test_check_int_3(): | |||
| with pytest.raises(TypeError): | |||
| check_int("str") | |||
| def test_check_negative_int2(): | |||
| input = np.random.randint(0, 100) | |||
| with pytest.raises(ValueError): | |||
| Validator.check_negative_int(input) | |||
| def test_check_negative_int3(): | |||
| with pytest.raises(ValueError): | |||
| Validator.check_negative_int(3.3) | |||
| def test_check_int_4(): | |||
| def test_check_negative_int4(): | |||
| with pytest.raises(TypeError): | |||
| check_int(True) | |||
| Validator.check_negative_int("str") | |||
| def test_check_non_positive_int1(): | |||
| input = np.random.randint(-100, 0) | |||
| assert Validator.check_non_positive_int(input) == input | |||
| def test_check_int_5(): | |||
| check_int(0) | |||
| check_int(1) | |||
| with pytest.raises(TypeError): | |||
| check_int(True) | |||
| with pytest.raises(TypeError): | |||
| check_int(False) | |||
| def test_check_non_positive_int2(): | |||
| input = np.random.randint(1, 100) | |||
| with pytest.raises(ValueError): | |||
| Validator.check_non_positive_int(input) | |||
| def test_check_non_positive_int3(): | |||
| with pytest.raises(ValueError): | |||
| Validator.check_non_positive_int(3.3) | |||
| def test_check_non_positive_int4(): | |||
| with pytest.raises(TypeError): | |||
| Validator.check_non_positive_int("str") | |||
| def test_check_bool_1(): | |||
| assert Validator.check_bool(True) | |||