| @@ -111,6 +111,24 @@ def check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=N | |||||
| return arg_value | return arg_value | ||||
| def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None): | |||||
| """ | |||||
| Checks input value is float type or not. | |||||
| Usage: | |||||
| - number = check_is_number(number, int) | |||||
| - number = check_is_number(number, int, "bias") | |||||
| - number = check_is_number(number, int, "bias", "bias_class") | |||||
| """ | |||||
| prim_name = f'in \'{prim_name}\'' if prim_name else '' | |||||
| arg_name = f'\'{prim_name}\'' if arg_name else 'Input value' | |||||
| if isinstance(arg_value, arg_type): | |||||
| if math.isinf(arg_value) or math.isnan(arg_value): | |||||
| raise ValueError(f'{arg_name} {prim_name} must be legal float, but got `{arg_value}`.') | |||||
| return arg_value | |||||
| raise TypeError(f'{arg_name} {prim_name} must be float, but got `{type(arg_value).__name__}`') | |||||
| class Validator: | class Validator: | ||||
| """validator for checking input parameters""" | """validator for checking input parameters""" | ||||
| @@ -140,6 +158,18 @@ class Validator: | |||||
| f' with type `{type(arg_value).__name__}`.') | f' with type `{type(arg_value).__name__}`.') | ||||
| return arg_value | return arg_value | ||||
| @staticmethod | |||||
| def check_is_int(arg_value, arg_name=None, prim_name=None): | |||||
| """ | |||||
| Checks input value is float type or not. | |||||
| Usage: | |||||
| - number = check_is_int(number, int) | |||||
| - number = check_is_int(number, int, "bias") | |||||
| - number = check_is_int(number, int, "bias", "bias_class") | |||||
| """ | |||||
| check_is_number(arg_value, int, arg_name, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_positive_int(arg_value, arg_name=None, prim_name=None): | def check_positive_int(arg_value, arg_name=None, prim_name=None): | ||||
| """ | """ | ||||
| @@ -184,6 +214,18 @@ class Validator: | |||||
| """ | """ | ||||
| return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) | return check_number(arg_value, 0, Rel.GE, int, arg_name, prim_name) | ||||
| @staticmethod | |||||
| def check_is_float(arg_value, arg_name=None, prim_name=None): | |||||
| """ | |||||
| Checks input value is float type or not. | |||||
| Usage: | |||||
| - number = check_is_float(number, int) | |||||
| - number = check_is_float(number, int, "bias") | |||||
| - number = check_is_float(number, int, "bias", "bias_class") | |||||
| """ | |||||
| check_is_number(arg_value, float, arg_name, prim_name) | |||||
| @staticmethod | @staticmethod | ||||
| def check_positive_float(arg_value, arg_name=None, prim_name=None): | def check_positive_float(arg_value, arg_name=None, prim_name=None): | ||||
| """ | """ | ||||
| @@ -453,16 +495,6 @@ class Validator: | |||||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' | raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' | ||||
| f' but got {get_typename(arg_type)}.') | f' but got {get_typename(arg_type)}.') | ||||
| @staticmethod | |||||
| def check_float_legal_value(arg_name, arg_value, prim_name): | |||||
| """Checks whether a legal value of float type""" | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||||
| if isinstance(arg_value, float): | |||||
| if math.isinf(arg_value) or math.isnan(arg_value): | |||||
| raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.") | |||||
| return arg_value | |||||
| raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | |||||
| @staticmethod | @staticmethod | ||||
| def check_reduce_shape(ori_shape, shape, axis, prim_name): | def check_reduce_shape(ori_shape, shape, axis, prim_name): | ||||
| """Checks whether shape is ori_shape reduced on axis""" | """Checks whether shape is ori_shape reduced on axis""" | ||||
| @@ -53,7 +53,7 @@ def piecewise_constant_lr(milestone, learning_rates): | |||||
| last_item = 0 | last_item = 0 | ||||
| for i, item in enumerate(milestone): | for i, item in enumerate(milestone): | ||||
| validator.check_positive_int(item, f'milestone[{i}]') | validator.check_positive_int(item, f'milestone[{i}]') | ||||
| validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None) | |||||
| validator.check_is_float(learning_rates[i], f'learning_rates[{i}]') | |||||
| if item < last_item: | if item < last_item: | ||||
| raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') | raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') | ||||
| lr += [learning_rates[i]] * (item - last_item) | lr += [learning_rates[i]] * (item - last_item) | ||||
| @@ -67,9 +67,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e | |||||
| validator.check_positive_int(step_per_epoch, 'step_per_epoch') | validator.check_positive_int(step_per_epoch, 'step_per_epoch') | ||||
| validator.check_positive_int(decay_epoch, 'decay_epoch') | validator.check_positive_int(decay_epoch, 'decay_epoch') | ||||
| validator.check_positive_float(learning_rate, 'learning_rate') | validator.check_positive_float(learning_rate, 'learning_rate') | ||||
| validator.check_float_legal_value('learning_rate', learning_rate, None) | |||||
| validator.check_is_float(learning_rate, 'learning_rate') | |||||
| validator.check_positive_float(decay_rate, 'decay_rate') | validator.check_positive_float(decay_rate, 'decay_rate') | ||||
| validator.check_float_legal_value('decay_rate', decay_rate, None) | |||||
| validator.check_is_float(decay_rate, 'decay_rate') | |||||
| validator.check_value_type('is_stair', is_stair, [bool], None) | validator.check_value_type('is_stair', is_stair, [bool], None) | ||||
| @@ -235,7 +235,7 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): | |||||
| raise TypeError("min_lr must be float.") | raise TypeError("min_lr must be float.") | ||||
| validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) | validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, None) | ||||
| validator.check_positive_float(max_lr, 'max_lr') | validator.check_positive_float(max_lr, 'max_lr') | ||||
| validator.check_float_legal_value('max_lr', max_lr, None) | |||||
| validator.check_is_float(max_lr, 'max_lr') | |||||
| validator.check_positive_int(total_step, 'total_step') | validator.check_positive_int(total_step, 'total_step') | ||||
| validator.check_positive_int(step_per_epoch, 'step_per_epoch') | validator.check_positive_int(step_per_epoch, 'step_per_epoch') | ||||
| validator.check_positive_int(decay_epoch, 'decay_epoch') | validator.check_positive_int(decay_epoch, 'decay_epoch') | ||||
| @@ -300,12 +300,12 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e | |||||
| [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] | [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] | ||||
| """ | """ | ||||
| validator.check_positive_float(learning_rate, 'learning_rate') | validator.check_positive_float(learning_rate, 'learning_rate') | ||||
| validator.check_float_legal_value('learning_rate', learning_rate, None) | |||||
| validator.check_is_float(learning_rate, 'learning_rate') | |||||
| if not isinstance(end_learning_rate, float): | if not isinstance(end_learning_rate, float): | ||||
| raise TypeError("end_learning_rate must be float.") | raise TypeError("end_learning_rate must be float.") | ||||
| validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) | validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, None) | ||||
| validator.check_positive_float(power, 'power') | validator.check_positive_float(power, 'power') | ||||
| validator.check_float_legal_value('power', power, None) | |||||
| validator.check_is_float(power, 'power') | |||||
| validator.check_positive_int(total_step, 'total_step') | validator.check_positive_int(total_step, 'total_step') | ||||
| validator.check_positive_int(step_per_epoch, 'step_per_epoch') | validator.check_positive_int(step_per_epoch, 'step_per_epoch') | ||||
| validator.check_positive_int(decay_epoch, 'decay_epoch') | validator.check_positive_int(decay_epoch, 'decay_epoch') | ||||
| @@ -55,11 +55,11 @@ class _Conv(Cell): | |||||
| self.weight_init = weight_init | self.weight_init = weight_init | ||||
| self.bias_init = bias_init | self.bias_init = bias_init | ||||
| if isinstance(padding, int): | if isinstance(padding, int): | ||||
| Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name) | |||||
| Validator.check_non_negative_int(padding, 'padding', self.cls_name) | |||||
| self.padding = padding | self.padding = padding | ||||
| elif isinstance(padding, tuple): | elif isinstance(padding, tuple): | ||||
| for pad in padding: | for pad in padding: | ||||
| Validator.check_integer('padding item', pad, 0, Rel.GE, self.cls_name) | |||||
| Validator.check_non_negative_int(pad, 'padding item', self.cls_name) | |||||
| self.padding = padding | self.padding = padding | ||||
| else: | else: | ||||
| raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding))) | raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding))) | ||||
| @@ -386,7 +386,7 @@ class Conv1d(_Conv): | |||||
| Validator.check_value_type("dilation", dilation, [int], self.cls_name) | Validator.check_value_type("dilation", dilation, [int], self.cls_name) | ||||
| Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) | Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) | ||||
| Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) | Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) | ||||
| Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name) | |||||
| Validator.check_non_negative_int(padding, 'padding', self.cls_name) | |||||
| Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) | Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) | ||||
| kernel_size = (1, kernel_size) | kernel_size = (1, kernel_size) | ||||
| stride = (1, stride) | stride = (1, stride) | ||||
| @@ -705,7 +705,7 @@ class Conv1dTranspose(_Conv): | |||||
| Validator.check_value_type("dilation", dilation, [int], self.cls_name) | Validator.check_value_type("dilation", dilation, [int], self.cls_name) | ||||
| Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) | Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) | ||||
| Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) | Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) | ||||
| Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name) | |||||
| Validator.check_non_negative_int(padding, 'padding', self.cls_name) | |||||
| Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) | Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) | ||||
| kernel_size = (1, kernel_size) | kernel_size = (1, kernel_size) | ||||
| stride = (1, stride) | stride = (1, stride) | ||||
| @@ -46,9 +46,9 @@ class LearningRateSchedule(Cell): | |||||
| def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name): | def _check_inputs(learning_rate, decay_rate, decay_steps, is_stair, cls_name): | ||||
| validator.check_positive_int(decay_steps, 'decay_steps', cls_name) | validator.check_positive_int(decay_steps, 'decay_steps', cls_name) | ||||
| validator.check_positive_float(learning_rate, 'learning_rate', cls_name) | validator.check_positive_float(learning_rate, 'learning_rate', cls_name) | ||||
| validator.check_float_legal_value('learning_rate', learning_rate, cls_name) | |||||
| validator.check_is_float(learning_rate, 'learning_rate', cls_name) | |||||
| validator.check_positive_float(decay_rate, 'decay_rate', cls_name) | validator.check_positive_float(decay_rate, 'decay_rate', cls_name) | ||||
| validator.check_float_legal_value('decay_rate', decay_rate, cls_name) | |||||
| validator.check_is_float(decay_rate, 'decay_rate', cls_name) | |||||
| validator.check_value_type('is_stair', is_stair, [bool], cls_name) | validator.check_value_type('is_stair', is_stair, [bool], cls_name) | ||||
| @@ -256,7 +256,7 @@ class CosineDecayLR(LearningRateSchedule): | |||||
| raise TypeError("min_lr must be float.") | raise TypeError("min_lr must be float.") | ||||
| validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | validator.check_number_range("min_lr", min_lr, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | ||||
| validator.check_positive_float(max_lr, 'max_lr', self.cls_name) | validator.check_positive_float(max_lr, 'max_lr', self.cls_name) | ||||
| validator.check_float_legal_value('max_lr', max_lr, self.cls_name) | |||||
| validator.check_is_float(max_lr, 'max_lr', self.cls_name) | |||||
| validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) | validator.check_positive_int(decay_steps, "decay_steps", self.cls_name) | ||||
| if min_lr >= max_lr: | if min_lr >= max_lr: | ||||
| raise ValueError('`max_lr` should be greater than `min_lr`.') | raise ValueError('`max_lr` should be greater than `min_lr`.') | ||||
| @@ -319,7 +319,7 @@ class PolynomialDecayLR(LearningRateSchedule): | |||||
| def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False): | def __init__(self, learning_rate, end_learning_rate, decay_steps, power, update_decay_steps=False): | ||||
| super(PolynomialDecayLR, self).__init__() | super(PolynomialDecayLR, self).__init__() | ||||
| validator.check_positive_float(learning_rate, 'learning_rate') | validator.check_positive_float(learning_rate, 'learning_rate') | ||||
| validator.check_float_legal_value('learning_rate', learning_rate, None) | |||||
| validator.check_is_float(learning_rate, 'learning_rate') | |||||
| if not isinstance(end_learning_rate, float): | if not isinstance(end_learning_rate, float): | ||||
| raise TypeError("end_learning_rate must be float.") | raise TypeError("end_learning_rate must be float.") | ||||
| validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, | validator.check_number_range("end_learning_rate", end_learning_rate, 0.0, float("inf"), Rel.INC_LEFT, | ||||
| @@ -327,7 +327,7 @@ class PolynomialDecayLR(LearningRateSchedule): | |||||
| validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) | validator.check_positive_int(decay_steps, 'decay_steps', self.cls_name) | ||||
| validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) | validator.check_value_type('update_decay_steps', update_decay_steps, [bool], self.cls_name) | ||||
| validator.check_positive_float(power, 'power', self.cls_name) | validator.check_positive_float(power, 'power', self.cls_name) | ||||
| validator.check_float_legal_value('power', power, self.cls_name) | |||||
| validator.check_is_float(power, 'power', self.cls_name) | |||||
| self.decay_steps = decay_steps | self.decay_steps = decay_steps | ||||
| self.start_learning_rate = learning_rate | self.start_learning_rate = learning_rate | ||||
| @@ -17,7 +17,6 @@ from mindspore import context | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from mindspore._checkparam import Rel | |||||
| from mindspore.common import get_seed | from mindspore.common import get_seed | ||||
| from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\ | from ._utils.utils import raise_none_error, cast_to_tensor, set_param_type, cast_type_for_device,\ | ||||
| raise_not_implemented_util | raise_not_implemented_util | ||||
| @@ -64,7 +63,7 @@ class Distribution(Cell): | |||||
| if seed is None: | if seed is None: | ||||
| seed = 0 | seed = 0 | ||||
| validator.check_value_type('name', name, [str], type(self).__name__) | validator.check_value_type('name', name, [str], type(self).__name__) | ||||
| validator.check_integer('seed', seed, 0, Rel.GE, name) | |||||
| validator.check_non_negative_int(seed, 'seed', name) | |||||
| self._name = name | self._name = name | ||||
| self._seed = seed | self._seed = seed | ||||
| @@ -141,7 +141,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer): | |||||
| if self.is_ascend: | if self.is_ascend: | ||||
| self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) | self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) | ||||
| else: | else: | ||||
| self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name) | |||||
| self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) | |||||
| self.init_prim_io_names( | self.init_prim_io_names( | ||||
| inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) | inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up']) | ||||
| @@ -226,10 +226,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer): | |||||
| 'training', training, (bool,), self.name) | 'training', training, (bool,), self.name) | ||||
| self.ema_decay = validator.check_number_range( | self.ema_decay = validator.check_number_range( | ||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | ||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_integer( | |||||
| 'quant_delay', quant_delay, 0, Rel.GE, self.name) | |||||
| self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | |||||
| self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) | |||||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], | self.init_prim_io_names(inputs=['x', 'min', 'max'], | ||||
| outputs=['out']) | outputs=['out']) | ||||
| @@ -275,8 +273,7 @@ class FakeQuantPerLayerGrad(PrimitiveWithInfer): | |||||
| raise ValueError( | raise ValueError( | ||||
| f"For '{self.name}' attr \'num_bits\' is not support.") | f"For '{self.name}' attr \'num_bits\' is not support.") | ||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | |||||
| self.quant_delay = validator.check_value_type( | self.quant_delay = validator.check_value_type( | ||||
| 'quant_delay', quant_delay, (int,), self.name) | 'quant_delay', quant_delay, (int,), self.name) | ||||
| self.symmetric = validator.check_value_type( | self.symmetric = validator.check_value_type( | ||||
| @@ -371,14 +368,12 @@ class FakeQuantPerChannel(PrimitiveWithInfer): | |||||
| 'training', training, (bool,), self.name) | 'training', training, (bool,), self.name) | ||||
| self.ema_decay = validator.check_number_range( | self.ema_decay = validator.check_number_range( | ||||
| 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) | ||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.quant_delay = validator.check_integer( | |||||
| 'quant_delay', quant_delay, 0, Rel.GE, self.name) | |||||
| self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | |||||
| self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name) | |||||
| if self.is_ascend: | if self.is_ascend: | ||||
| self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) | self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name) | ||||
| else: | else: | ||||
| self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name) | |||||
| self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name) | |||||
| self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) | self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) | ||||
| def infer_shape(self, x_shape, min_shape, max_shape): | def infer_shape(self, x_shape, min_shape, max_shape): | ||||
| @@ -433,16 +428,14 @@ class FakeQuantPerChannelGrad(PrimitiveWithInfer): | |||||
| raise ValueError( | raise ValueError( | ||||
| f"For '{self.name}' attr \'num_bits\' is not support.") | f"For '{self.name}' attr \'num_bits\' is not support.") | ||||
| self.num_bits = validator.check_integer( | |||||
| 'num_bits', num_bits, 0, Rel.GT, self.name) | |||||
| self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name) | |||||
| self.quant_delay = validator.check_value_type( | self.quant_delay = validator.check_value_type( | ||||
| 'quant_delay', quant_delay, (int,), self.name) | 'quant_delay', quant_delay, (int,), self.name) | ||||
| self.symmetric = validator.check_value_type( | self.symmetric = validator.check_value_type( | ||||
| 'symmetric', symmetric, (bool,), self.name) | 'symmetric', symmetric, (bool,), self.name) | ||||
| self.narrow_range = validator.check_value_type( | self.narrow_range = validator.check_value_type( | ||||
| 'narrow_range', narrow_range, (bool,), self.name) | 'narrow_range', narrow_range, (bool,), self.name) | ||||
| self.channel_axis = validator.check_integer( | |||||
| 'channel axis', channel_axis, 0, Rel.GE, self.name) | |||||
| self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name) | |||||
| self.init_prim_io_names( | self.init_prim_io_names( | ||||
| inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) | inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) | ||||
| @@ -516,7 +516,7 @@ class Im2Col(PrimitiveWithInfer): | |||||
| self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) | self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) | ||||
| self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) | self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) | ||||
| if self.pad_mode == 'pad': | if self.pad_mode == 'pad': | ||||
| validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(self.pad, 'pad', self.name) | |||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| def infer_shape(self, x_shape): | def infer_shape(self, x_shape): | ||||
| @@ -763,7 +763,7 @@ class Split(PrimitiveWithInfer): | |||||
| x_shape = list(x['shape']) | x_shape = list(x['shape']) | ||||
| dim = len(x_shape) | dim = len(x_shape) | ||||
| validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) | validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) | ||||
| validator.check_integer("output_num", self.output_num, 0, Rel.GT, self.name) | |||||
| validator.check_positive_int(self.output_num, "output_num", self.name) | |||||
| output_valid_check = x_shape[self.axis] % self.output_num | output_valid_check = x_shape[self.axis] % self.output_num | ||||
| if output_valid_check != 0: | if output_valid_check != 0: | ||||
| raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by" | raise ValueError(f"x_shape[{self.axis}] {x_shape[self.axis]} must be divide exactly by" | ||||
| @@ -846,7 +846,7 @@ class TruncatedNormal(PrimitiveWithInfer): | |||||
| shape_value = shape['value'] | shape_value = shape['value'] | ||||
| validator.check_value_type("shape", shape_value, [tuple], self.name) | validator.check_value_type("shape", shape_value, [tuple], self.name) | ||||
| for i, value in enumerate(shape_value): | for i, value in enumerate(shape_value): | ||||
| validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT, self.name) | |||||
| validator.check_positive_int(value, f'{i}th value of shape', self.name) | |||||
| out = {'shape': shape_value, | out = {'shape': shape_value, | ||||
| 'dtype': mstype.tensor_type(self.dtype), | 'dtype': mstype.tensor_type(self.dtype), | ||||
| 'value': None} | 'value': None} | ||||
| @@ -2180,13 +2180,13 @@ class StridedSlice(PrimitiveWithInfer): | |||||
| shrink_axis_mask=0): | shrink_axis_mask=0): | ||||
| """Initialize StrideSlice""" | """Initialize StrideSlice""" | ||||
| self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output']) | ||||
| validator.check_integer('begin_mask', begin_mask, 0, Rel.GE, self.name) | |||||
| validator.check_integer('end_mask', end_mask, 0, Rel.GE, self.name) | |||||
| validator.check_integer('ellipsis_mask', ellipsis_mask, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(begin_mask, 'begin_mask', self.name) | |||||
| validator.check_non_negative_int(end_mask, 'end_mask', self.name) | |||||
| validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name) | |||||
| if len(tuple(filter(lambda x: x == '1', bin(ellipsis_mask)[-1:1:-1]))) > 1: | if len(tuple(filter(lambda x: x == '1', bin(ellipsis_mask)[-1:1:-1]))) > 1: | ||||
| raise ValueError(f"For '{self.name}', only support one ellipsis in the index, but got {end_mask}.") | raise ValueError(f"For '{self.name}', only support one ellipsis in the index, but got {end_mask}.") | ||||
| validator.check_integer('new_axis_mask', new_axis_mask, 0, Rel.GE, self.name) | |||||
| validator.check_integer('shrink_axis_mask', shrink_axis_mask, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(new_axis_mask, 'new_axis_mask', self.name) | |||||
| validator.check_non_negative_int(shrink_axis_mask, 'shrink_axis_mask', self.name) | |||||
| def __infer__(self, x, begin, end, strides): | def __infer__(self, x, begin, end, strides): | ||||
| begin_v, end_v, strides_v = begin['value'], end['value'], strides['value'] | begin_v, end_v, strides_v = begin['value'], end['value'], strides['value'] | ||||
| @@ -2507,7 +2507,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): | |||||
| validator.check_value_type("align_corners", align_corners, [bool], self.name) | validator.check_value_type("align_corners", align_corners, [bool], self.name) | ||||
| validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name) | validator.check_integer("length of size", len(size), 2, Rel.EQ, self.name) | ||||
| for i, value in enumerate(size): | for i, value in enumerate(size): | ||||
| validator.check_integer(f'{i}th value of size', value, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(value, f'{i}th value of size', self.name) | |||||
| self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) | self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) | ||||
| def infer_shape(self, x): | def infer_shape(self, x): | ||||
| @@ -3176,7 +3176,7 @@ class SpaceToBatch(PrimitiveWithInfer): | |||||
| self.block_size = block_size | self.block_size = block_size | ||||
| validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name) | validator.check('paddings shape', np.array(paddings).shape, '', (2, 2), Rel.EQ, self.name) | ||||
| for elem in itertools.chain(*paddings): | for elem in itertools.chain(*paddings): | ||||
| validator.check_integer('paddings element', elem, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(elem, 'paddings element', self.name) | |||||
| validator.check_value_type('paddings element', elem, [int], self.name) | validator.check_value_type('paddings element', elem, [int], self.name) | ||||
| self.paddings = paddings | self.paddings = paddings | ||||
| @@ -3248,7 +3248,7 @@ class BatchToSpace(PrimitiveWithInfer): | |||||
| validator.check_value_type('crops type', crops, [list, tuple], self.name) | validator.check_value_type('crops type', crops, [list, tuple], self.name) | ||||
| validator.check('crops shape', np.array(crops).shape, '', (2, 2)) | validator.check('crops shape', np.array(crops).shape, '', (2, 2)) | ||||
| for elem in itertools.chain(*crops): | for elem in itertools.chain(*crops): | ||||
| validator.check_integer('crops element', elem, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(elem, 'crops element', self.name) | |||||
| validator.check_value_type('crops element', elem, [int], self.name) | validator.check_value_type('crops element', elem, [int], self.name) | ||||
| self.crops = crops | self.crops = crops | ||||
| @@ -3333,7 +3333,7 @@ class SpaceToBatchND(PrimitiveWithInfer): | |||||
| validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name) | validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name) | ||||
| validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name) | validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name) | ||||
| for elem in itertools.chain(*paddings): | for elem in itertools.chain(*paddings): | ||||
| validator.check_integer('paddings element', elem, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(elem, 'paddings element', self.name) | |||||
| validator.check_value_type('paddings element', elem, [int], self.name) | validator.check_value_type('paddings element', elem, [int], self.name) | ||||
| self.paddings = paddings | self.paddings = paddings | ||||
| block_shape_append = [1] + list(self.block_shape) | block_shape_append = [1] + list(self.block_shape) | ||||
| @@ -3426,7 +3426,7 @@ class BatchToSpaceND(PrimitiveWithInfer): | |||||
| validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name) | validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name) | ||||
| validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name) | validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name) | ||||
| for elem in itertools.chain(*crops): | for elem in itertools.chain(*crops): | ||||
| validator.check_integer('crops element', elem, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(elem, 'crops element', self.name) | |||||
| validator.check_value_type('crops element', elem, [int], self.name) | validator.check_value_type('crops element', elem, [int], self.name) | ||||
| self.crops = crops | self.crops = crops | ||||
| block_shape_append = [1] + list(self.block_shape) | block_shape_append = [1] + list(self.block_shape) | ||||
| @@ -3019,7 +3019,7 @@ class NMSWithMask(PrimitiveWithInfer): | |||||
| def infer_shape(self, bboxes_shape): | def infer_shape(self, bboxes_shape): | ||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) | ||||
| validator.check_integer("bboxes.shape[0]", bboxes_shape[0], 0, Rel.GT, cls_name) | |||||
| validator.check_positive_int(bboxes_shape[0], "bboxes.shape[0]", cls_name) | |||||
| validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | validator.check_integer("bboxes.shape[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) | ||||
| num = bboxes_shape[0] | num = bboxes_shape[0] | ||||
| return (bboxes_shape, (num,), (num,)) | return (bboxes_shape, (num,), (num,)) | ||||
| @@ -1001,7 +1001,7 @@ class Conv2D(PrimitiveWithInfer): | |||||
| raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") | raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") | ||||
| if self.pad_mode == 'pad': | if self.pad_mode == 'pad': | ||||
| for item in pad: | for item in pad: | ||||
| validator.check_integer('pad item', item, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||||
| self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) | self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) | ||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| @@ -1139,7 +1139,7 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): | |||||
| raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") | raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") | ||||
| if self.pad_mode == 'pad': | if self.pad_mode == 'pad': | ||||
| for item in pad: | for item in pad: | ||||
| validator.check_integer('pad item', item, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||||
| self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name) | self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name) | ||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name) | self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name) | ||||
| @@ -1525,7 +1525,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): | |||||
| raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") | raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") | ||||
| if self.pad_mode == 'pad': | if self.pad_mode == 'pad': | ||||
| for item in pad: | for item in pad: | ||||
| validator.check_integer('pad item', item, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(item, 'pad item', self.name) | |||||
| pad_mode = pad_mode.upper() | pad_mode = pad_mode.upper() | ||||
| self.add_prim_attr('pad_mode', pad_mode) | self.add_prim_attr('pad_mode', pad_mode) | ||||
| @@ -1534,7 +1534,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): | |||||
| self.add_prim_attr('data_format', "NCHW") | self.add_prim_attr('data_format', "NCHW") | ||||
| if pad_list: | if pad_list: | ||||
| for x in pad_list: | for x in pad_list: | ||||
| validator.check_integer('element of pad_list', x, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(x, 'element of pad_list', self.name) | |||||
| self.pad_list = pad_list | self.pad_list = pad_list | ||||
| def __infer__(self, doutput, w, x_size): | def __infer__(self, doutput, w, x_size): | ||||
| @@ -2568,7 +2568,7 @@ class OneHot(PrimitiveWithInfer): | |||||
| indices_shp = indices['shape'] | indices_shp = indices['shape'] | ||||
| validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name) | validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name) | ||||
| depth_val = depth['value'] | depth_val = depth['value'] | ||||
| validator.check_integer("depth", depth_val, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(depth_val, "depth", self.name) | |||||
| # create new dimension at end if self.axis is -1 | # create new dimension at end if self.axis is -1 | ||||
| _ = indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val) | _ = indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val) | ||||
| @@ -5722,7 +5722,7 @@ class LRN(PrimitiveWithInfer): | |||||
| validator.check_value_type("beta", beta, [float], self.name) | validator.check_value_type("beta", beta, [float], self.name) | ||||
| validator.check_value_type("norm_region", norm_region, [str], self.name) | validator.check_value_type("norm_region", norm_region, [str], self.name) | ||||
| validator.check_string(norm_region, ['ACROSS_CHANNELS'], 'norm_region', self.name) | validator.check_string(norm_region, ['ACROSS_CHANNELS'], 'norm_region', self.name) | ||||
| validator.check_integer("depth_radius", depth_radius, 0, Rel.GE, self.name) | |||||
| validator.check_non_negative_int(depth_radius, "depth_radius", self.name) | |||||
| def infer_dtype(self, x_dtype): | def infer_dtype(self, x_dtype): | ||||
| validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name) | validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32,), self.name) | ||||
| @@ -44,8 +44,8 @@ class StandardNormal(PrimitiveWithInfer): | |||||
| def __init__(self, seed=0, seed2=0): | def __init__(self, seed=0, seed2=0): | ||||
| """Initialize StandardNormal""" | """Initialize StandardNormal""" | ||||
| self.init_prim_io_names(inputs=['shape'], outputs=['output']) | self.init_prim_io_names(inputs=['shape'], outputs=['output']) | ||||
| Validator.check_integer("seed", seed, 0, Rel.GE, self.name) | |||||
| Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) | |||||
| Validator.check_non_negative_int(seed, "seed", self.name) | |||||
| Validator.check_non_negative_int(seed2, "seed2", self.name) | |||||
| def __infer__(self, shape): | def __infer__(self, shape): | ||||
| shape_v = shape["value"] | shape_v = shape["value"] | ||||
| @@ -141,8 +141,8 @@ class Gamma(PrimitiveWithInfer): | |||||
| def __init__(self, seed=0, seed2=0): | def __init__(self, seed=0, seed2=0): | ||||
| """Initialize Gamma""" | """Initialize Gamma""" | ||||
| self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output']) | self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output']) | ||||
| Validator.check_integer("seed", seed, 0, Rel.GE, self.name) | |||||
| Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) | |||||
| Validator.check_non_negative_int(seed, "seed", self.name) | |||||
| Validator.check_non_negative_int(seed2, "seed2", self.name) | |||||
| def __infer__(self, shape, alpha, beta): | def __infer__(self, shape, alpha, beta): | ||||
| shape_v = shape["value"] | shape_v = shape["value"] | ||||
| @@ -193,8 +193,8 @@ class Poisson(PrimitiveWithInfer): | |||||
| def __init__(self, seed=0, seed2=0): | def __init__(self, seed=0, seed2=0): | ||||
| """Initialize Poisson""" | """Initialize Poisson""" | ||||
| self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output']) | self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output']) | ||||
| Validator.check_integer("seed", seed, 0, Rel.GE, self.name) | |||||
| Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) | |||||
| Validator.check_non_negative_int(seed, "seed", self.name) | |||||
| Validator.check_non_negative_int(seed2, "seed2", self.name) | |||||
| def __infer__(self, shape, mean): | def __infer__(self, shape, mean): | ||||
| shape_v = shape["value"] | shape_v = shape["value"] | ||||
| @@ -249,8 +249,8 @@ class UniformInt(PrimitiveWithInfer): | |||||
| def __init__(self, seed=0, seed2=0): | def __init__(self, seed=0, seed2=0): | ||||
| """Initialize UniformInt""" | """Initialize UniformInt""" | ||||
| self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output']) | self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output']) | ||||
| Validator.check_integer("seed", seed, 0, Rel.GE, self.name) | |||||
| Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) | |||||
| Validator.check_non_negative_int(seed, "seed", self.name) | |||||
| Validator.check_non_negative_int(seed2, "seed2", self.name) | |||||
| def __infer__(self, shape, minval, maxval): | def __infer__(self, shape, minval, maxval): | ||||
| shape_v = shape["value"] | shape_v = shape["value"] | ||||
| @@ -296,8 +296,8 @@ class UniformReal(PrimitiveWithInfer): | |||||
| def __init__(self, seed=0, seed2=0): | def __init__(self, seed=0, seed2=0): | ||||
| """Initialize UniformReal""" | """Initialize UniformReal""" | ||||
| self.init_prim_io_names(inputs=['shape'], outputs=['output']) | self.init_prim_io_names(inputs=['shape'], outputs=['output']) | ||||
| Validator.check_integer("seed", seed, 0, Rel.GE, self.name) | |||||
| Validator.check_integer("seed2", seed2, 0, Rel.GE, self.name) | |||||
| Validator.check_non_negative_int(seed, "seed", self.name) | |||||
| Validator.check_non_negative_int(seed2, "seed2", self.name) | |||||
| def __infer__(self, shape): | def __infer__(self, shape): | ||||
| shape_v = shape["value"] | shape_v = shape["value"] | ||||
| @@ -449,7 +449,7 @@ class Multinomial(PrimitiveWithInfer): | |||||
| def __init__(self, seed=0): | def __init__(self, seed=0): | ||||
| """init""" | """init""" | ||||
| Validator.check_value_type("seed", seed, [int], self.name) | Validator.check_value_type("seed", seed, [int], self.name) | ||||
| Validator.check_integer("seed", seed, 0, Rel.GE, self.name) | |||||
| Validator.check_non_negative_int(seed, "seed", self.name) | |||||
| self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) | self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) | ||||
| def __infer__(self, inputs, num_samples): | def __infer__(self, inputs, num_samples): | ||||