| @@ -17,7 +17,7 @@ import re | |||||
| from enum import Enum | from enum import Enum | ||||
| from functools import reduce | from functools import reduce | ||||
| from itertools import repeat | from itertools import repeat | ||||
| from collections import Iterable | |||||
| from collections.abc import Iterable | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -98,7 +98,7 @@ class Validator: | |||||
| """validator for checking input parameters""" | """validator for checking input parameters""" | ||||
| @staticmethod | @staticmethod | ||||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None): | |||||
| def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): | |||||
| """ | """ | ||||
| Method for judging relation between two int values or list/tuple made up of ints. | Method for judging relation between two int values or list/tuple made up of ints. | ||||
| @@ -108,8 +108,8 @@ class Validator: | |||||
| rel_fn = Rel.get_fns(rel) | rel_fn = Rel.get_fns(rel) | ||||
| if not rel_fn(arg_value, value): | if not rel_fn(arg_value, value): | ||||
| rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') | ||||
| msg_prefix = f'For {prim_name} the' if prim_name else "The" | |||||
| raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||||
| raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') | |||||
| @staticmethod | @staticmethod | ||||
| def check_integer(arg_name, arg_value, value, rel, prim_name): | def check_integer(arg_name, arg_value, value, rel, prim_name): | ||||
| @@ -118,8 +118,17 @@ class Validator: | |||||
| type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) | ||||
| if type_mismatch or not rel_fn(arg_value, value): | if type_mismatch or not rel_fn(arg_value, value): | ||||
| rel_str = Rel.get_strs(rel).format(value) | rel_str = Rel.get_strs(rel).format(value) | ||||
| raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},' | |||||
| f' but got {arg_value}.') | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||||
| raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_number(arg_name, arg_value, value, rel, prim_name): | |||||
| """Integer value judgment.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| if not rel_fn(arg_value, value): | |||||
| rel_str = Rel.get_strs(rel).format(value) | |||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') | |||||
| return arg_value | return arg_value | ||||
| @staticmethod | @staticmethod | ||||
| @@ -133,9 +142,46 @@ class Validator: | |||||
| f' but got {arg_value}.') | f' but got {arg_value}.') | ||||
| return arg_value | return arg_value | ||||
| @staticmethod | |||||
| def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): | |||||
| """Method for checking whether a numeric value is in some range.""" | |||||
| rel_fn = Rel.get_fns(rel) | |||||
| if not rel_fn(arg_value, lower_limit, upper_limit): | |||||
| rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) | |||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.') | |||||
| return arg_value | |||||
| @staticmethod | |||||
| def check_string(arg_name, arg_value, valid_values, prim_name): | |||||
| """Checks whether a string is in some value list""" | |||||
| if isinstance(arg_value, str) and arg_value in valid_values: | |||||
| return arg_value | |||||
| if len(valid_values) == 1: | |||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},' | |||||
| f' but got {arg_value}.') | |||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},' | |||||
| f' but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_pad_value_by_mode(pad_mode, padding, prim_name): | |||||
| """Validates value of padding according to pad_mode""" | |||||
| if pad_mode != 'pad' and padding != 0: | |||||
| raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.") | |||||
| return padding | |||||
| @staticmethod | |||||
| def check_float_positive(arg_name, arg_value, prim_name): | |||||
| """Float type judgment.""" | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" | |||||
| if isinstance(arg_value, float): | |||||
| if arg_value > 0: | |||||
| return arg_value | |||||
| raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.") | |||||
| raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") | |||||
| @staticmethod | @staticmethod | ||||
| def check_subclass(arg_name, type_, template_type, prim_name): | def check_subclass(arg_name, type_, template_type, prim_name): | ||||
| """Check whether some type is sublcass of another type""" | |||||
| """Checks whether some type is sublcass of another type""" | |||||
| if not isinstance(template_type, Iterable): | if not isinstance(template_type, Iterable): | ||||
| template_type = (template_type,) | template_type = (template_type,) | ||||
| if not any([mstype.issubclass_(type_, x) for x in template_type]): | if not any([mstype.issubclass_(type_, x) for x in template_type]): | ||||
| @@ -143,16 +189,44 @@ class Validator: | |||||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' | ||||
| f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') | ||||
| @staticmethod | |||||
| def check_const_input(arg_name, arg_value, prim_name): | |||||
| """Check valid value.""" | |||||
| if arg_value is None: | |||||
| raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') | |||||
| @staticmethod | |||||
| def check_scalar_type_same(args, valid_values, prim_name): | |||||
| """check whether the types of inputs are the same.""" | |||||
| def _check_tensor_type(arg): | |||||
| arg_key, arg_val = arg | |||||
| elem_type = arg_val | |||||
| if not elem_type in valid_values: | |||||
| raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},' | |||||
| f' but `{arg_key}` is {elem_type}.') | |||||
| return (arg_key, elem_type) | |||||
| def _check_types_same(arg1, arg2): | |||||
| arg1_name, arg1_type = arg1 | |||||
| arg2_name, arg2_type = arg2 | |||||
| if arg1_type != arg2_type: | |||||
| raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' | |||||
| f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') | |||||
| return arg1 | |||||
| elem_types = map(_check_tensor_type, args.items()) | |||||
| reduce(_check_types_same, elem_types) | |||||
| @staticmethod | @staticmethod | ||||
| def check_tensor_type_same(args, valid_values, prim_name): | def check_tensor_type_same(args, valid_values, prim_name): | ||||
| """check whether the element types of input tensors are the same.""" | |||||
| """Checks whether the element types of input tensors are the same.""" | |||||
| def _check_tensor_type(arg): | def _check_tensor_type(arg): | ||||
| arg_key, arg_val = arg | arg_key, arg_val = arg | ||||
| Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) | Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) | ||||
| elem_type = arg_val.element_type() | elem_type = arg_val.element_type() | ||||
| if not elem_type in valid_values: | if not elem_type in valid_values: | ||||
| raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' | raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' | ||||
| f' but `{arg_key}` is {elem_type}.') | |||||
| f' but element type of `{arg_key}` is {elem_type}.') | |||||
| return (arg_key, elem_type) | return (arg_key, elem_type) | ||||
| def _check_types_same(arg1, arg2): | def _check_types_same(arg1, arg2): | ||||
| @@ -168,8 +242,13 @@ class Validator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name): | |||||
| """check whether the types of inputs are the same. if the input args are tensors, check their element types""" | |||||
| def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): | |||||
| """ | |||||
| Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. | |||||
| If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. | |||||
| """ | |||||
| def _check_argument_type(arg): | def _check_argument_type(arg): | ||||
| arg_key, arg_val = arg | arg_key, arg_val = arg | ||||
| if isinstance(arg_val, type(mstype.tensor)): | if isinstance(arg_val, type(mstype.tensor)): | ||||
| @@ -188,6 +267,9 @@ class Validator: | |||||
| arg2_type = arg2_type.element_type() | arg2_type = arg2_type.element_type() | ||||
| elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): | elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): | ||||
| pass | pass | ||||
| elif allow_mix: | |||||
| arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type | |||||
| arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type | |||||
| else: | else: | ||||
| excp_flag = True | excp_flag = True | ||||
| @@ -199,13 +281,14 @@ class Validator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_value_type(arg_name, arg_value, valid_types, prim_name): | def check_value_type(arg_name, arg_value, valid_types, prim_name): | ||||
| """Check whether a values is instance of some types.""" | |||||
| """Checks whether a value is instance of some types.""" | |||||
| valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | |||||
| def raise_error_msg(): | def raise_error_msg(): | ||||
| """func for raising error message when check failed""" | """func for raising error message when check failed""" | ||||
| type_names = [t.__name__ for t in valid_types] | type_names = [t.__name__ for t in valid_types] | ||||
| num_types = len(valid_types) | num_types = len(valid_types) | ||||
| raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be ' | |||||
| f'{"one of " if num_types > 1 else ""}' | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||||
| raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' | |||||
| f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') | ||||
| # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and | # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and | ||||
| @@ -216,6 +299,23 @@ class Validator: | |||||
| return arg_value | return arg_value | ||||
| raise_error_msg() | raise_error_msg() | ||||
| @staticmethod | |||||
| def check_type_name(arg_name, arg_type, valid_types, prim_name): | |||||
| """Checks whether a type in some specified types""" | |||||
| valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) | |||||
| def get_typename(t): | |||||
| return t.__name__ if hasattr(t, '__name__') else str(t) | |||||
| if arg_type in valid_types: | |||||
| return arg_type | |||||
| type_names = [get_typename(t) for t in valid_types] | |||||
| msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' | |||||
| if len(valid_types) == 1: | |||||
| raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' | |||||
| f' but got {get_typename(arg_type)}.') | |||||
| class ParamValidator: | class ParamValidator: | ||||
| """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" | ||||
| @@ -103,6 +103,10 @@ class Cell: | |||||
| def parameter_layout_dict(self): | def parameter_layout_dict(self): | ||||
| return self._parameter_layout_dict | return self._parameter_layout_dict | ||||
| @property | |||||
| def cls_name(self): | |||||
| return self.__class__.__name__ | |||||
| @parameter_layout_dict.setter | @parameter_layout_dict.setter | ||||
| def parameter_layout_dict(self, value): | def parameter_layout_dict(self, value): | ||||
| if not isinstance(value, dict): | if not isinstance(value, dict): | ||||
| @@ -15,7 +15,7 @@ | |||||
| """dynamic learning rate""" | """dynamic learning rate""" | ||||
| import math | import math | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| @@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates): | |||||
| >>> lr = piecewise_constant_lr(milestone, learning_rates) | >>> lr = piecewise_constant_lr(milestone, learning_rates) | ||||
| [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01] | [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01] | ||||
| """ | """ | ||||
| validator.check_type('milestone', milestone, (tuple, list)) | |||||
| validator.check_type('learning_rates', learning_rates, (tuple, list)) | |||||
| validator.check_value_type('milestone', milestone, (tuple, list), None) | |||||
| validator.check_value_type('learning_rates', learning_rates, (tuple, list), None) | |||||
| if len(milestone) != len(learning_rates): | if len(milestone) != len(learning_rates): | ||||
| raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.') | raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.') | ||||
| lr = [] | lr = [] | ||||
| last_item = 0 | last_item = 0 | ||||
| for i, item in enumerate(milestone): | for i, item in enumerate(milestone): | ||||
| validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT) | |||||
| validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float]) | |||||
| validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None) | |||||
| validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None) | |||||
| if item < last_item: | if item < last_item: | ||||
| raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') | raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') | ||||
| lr += [learning_rates[i]] * (item - last_item) | lr += [learning_rates[i]] * (item - last_item) | ||||
| @@ -62,12 +62,12 @@ def piecewise_constant_lr(milestone, learning_rates): | |||||
| def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair): | def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair): | ||||
| validator.check_integer('total_step', total_step, 0, Rel.GT) | |||||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) | |||||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) | |||||
| validator.check_float_positive('learning_rate', learning_rate) | |||||
| validator.check_float_positive('decay_rate', decay_rate) | |||||
| validator.check_type('is_stair', is_stair, [bool]) | |||||
| validator.check_integer('total_step', total_step, 0, Rel.GT, None) | |||||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) | |||||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) | |||||
| validator.check_float_positive('learning_rate', learning_rate, None) | |||||
| validator.check_float_positive('decay_rate', decay_rate, None) | |||||
| validator.check_value_type('is_stair', is_stair, [bool], None) | |||||
| def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): | def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): | ||||
| @@ -228,11 +228,11 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): | |||||
| >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) | >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) | ||||
| [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] | [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] | ||||
| """ | """ | ||||
| validator.check_float_positive('min_lr', min_lr) | |||||
| validator.check_float_positive('max_lr', max_lr) | |||||
| validator.check_integer('total_step', total_step, 0, Rel.GT) | |||||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) | |||||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) | |||||
| validator.check_float_positive('min_lr', min_lr, None) | |||||
| validator.check_float_positive('max_lr', max_lr, None) | |||||
| validator.check_integer('total_step', total_step, 0, Rel.GT, None) | |||||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) | |||||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) | |||||
| delta = 0.5 * (max_lr - min_lr) | delta = 0.5 * (max_lr - min_lr) | ||||
| lr = [] | lr = [] | ||||
| @@ -279,13 +279,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e | |||||
| >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) | >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) | ||||
| [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] | [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] | ||||
| """ | """ | ||||
| validator.check_float_positive('learning_rate', learning_rate) | |||||
| validator.check_float_positive('end_learning_rate', end_learning_rate) | |||||
| validator.check_integer('total_step', total_step, 0, Rel.GT) | |||||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) | |||||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) | |||||
| validator.check_type('power', power, [float]) | |||||
| validator.check_type('update_decay_epoch', update_decay_epoch, [bool]) | |||||
| validator.check_float_positive('learning_rate', learning_rate, None) | |||||
| validator.check_float_positive('end_learning_rate', end_learning_rate, None) | |||||
| validator.check_integer('total_step', total_step, 0, Rel.GT, None) | |||||
| validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) | |||||
| validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) | |||||
| validator.check_value_type('power', power, [float], None) | |||||
| validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) | |||||
| function = lambda x, y: (x, min(x, y)) | function = lambda x, y: (x, min(x, y)) | ||||
| if update_decay_epoch: | if update_decay_epoch: | ||||
| @@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter | |||||
| from mindspore._extends import cell_attr_register | from mindspore._extends import cell_attr_register | ||||
| from ..cell import Cell | from ..cell import Cell | ||||
| from .activation import get_activation | from .activation import get_activation | ||||
| from ..._checkparam import ParamValidator as validator | |||||
| from ..._checkparam import Validator as validator | |||||
| class Dropout(Cell): | class Dropout(Cell): | ||||
| @@ -73,7 +73,7 @@ class Dropout(Cell): | |||||
| super(Dropout, self).__init__() | super(Dropout, self).__init__() | ||||
| if keep_prob <= 0 or keep_prob > 1: | if keep_prob <= 0 or keep_prob > 1: | ||||
| raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) | raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) | ||||
| validator.check_subclass("dtype", dtype, mstype.number_type) | |||||
| validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) | |||||
| self.keep_prob = Tensor(keep_prob) | self.keep_prob = Tensor(keep_prob) | ||||
| self.seed0 = seed0 | self.seed0 = seed0 | ||||
| self.seed1 = seed1 | self.seed1 = seed1 | ||||
| @@ -421,7 +421,7 @@ class Pad(Cell): | |||||
| super(Pad, self).__init__() | super(Pad, self).__init__() | ||||
| self.mode = mode | self.mode = mode | ||||
| self.paddings = paddings | self.paddings = paddings | ||||
| validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"]) | |||||
| validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name) | |||||
| if not isinstance(paddings, tuple): | if not isinstance(paddings, tuple): | ||||
| raise TypeError('Paddings must be tuple type.') | raise TypeError('Paddings must be tuple type.') | ||||
| for item in paddings: | for item in paddings: | ||||
| @@ -19,7 +19,7 @@ from mindspore.ops import operations as P | |||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from ..cell import Cell | from ..cell import Cell | ||||
| from ..._checkparam import ParamValidator as validator | |||||
| from ..._checkparam import Validator as validator | |||||
| class Embedding(Cell): | class Embedding(Cell): | ||||
| @@ -59,7 +59,7 @@ class Embedding(Cell): | |||||
| """ | """ | ||||
| def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): | def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): | ||||
| super(Embedding, self).__init__() | super(Embedding, self).__init__() | ||||
| validator.check_subclass("dtype", dtype, mstype.number_type) | |||||
| validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) | |||||
| self.vocab_size = vocab_size | self.vocab_size = vocab_size | ||||
| self.embedding_size = embedding_size | self.embedding_size = embedding_size | ||||
| self.use_one_hot = use_one_hot | self.use_one_hot = use_one_hot | ||||
| @@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.ops.primitive import constexpr | from mindspore.ops.primitive import constexpr | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| from ..cell import Cell | from ..cell import Cell | ||||
| @@ -134,15 +134,15 @@ class SSIM(Cell): | |||||
| """ | """ | ||||
| def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): | def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): | ||||
| super(SSIM, self).__init__() | super(SSIM, self).__init__() | ||||
| validator.check_type('max_val', max_val, [int, float]) | |||||
| validator.check('max_val', max_val, '', 0.0, Rel.GT) | |||||
| validator.check_value_type('max_val', max_val, [int, float], self.cls_name) | |||||
| validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) | |||||
| self.max_val = max_val | self.max_val = max_val | ||||
| self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE) | |||||
| self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma) | |||||
| validator.check_type('k1', k1, [float]) | |||||
| self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER) | |||||
| validator.check_type('k2', k2, [float]) | |||||
| self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER) | |||||
| self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) | |||||
| self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) | |||||
| validator.check_value_type('k1', k1, [float], self.cls_name) | |||||
| self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) | |||||
| validator.check_value_type('k2', k2, [float], self.cls_name) | |||||
| self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) | |||||
| self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) | self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) | ||||
| def construct(self, img1, img2): | def construct(self, img1, img2): | ||||
| @@ -231,8 +231,8 @@ class PSNR(Cell): | |||||
| """ | """ | ||||
| def __init__(self, max_val=1.0): | def __init__(self, max_val=1.0): | ||||
| super(PSNR, self).__init__() | super(PSNR, self).__init__() | ||||
| validator.check_type('max_val', max_val, [int, float]) | |||||
| validator.check('max_val', max_val, '', 0.0, Rel.GT) | |||||
| validator.check_value_type('max_val', max_val, [int, float], self.cls_name) | |||||
| validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) | |||||
| self.max_val = max_val | self.max_val = max_val | ||||
| def construct(self, img1, img2): | def construct(self, img1, img2): | ||||
| @@ -17,7 +17,7 @@ from mindspore.ops import operations as P | |||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| class LSTM(Cell): | class LSTM(Cell): | ||||
| @@ -114,7 +114,7 @@ class LSTM(Cell): | |||||
| self.hidden_size = hidden_size | self.hidden_size = hidden_size | ||||
| self.num_layers = num_layers | self.num_layers = num_layers | ||||
| self.has_bias = has_bias | self.has_bias = has_bias | ||||
| self.batch_first = validator.check_type("batch_first", batch_first, [bool]) | |||||
| self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) | |||||
| self.dropout = float(dropout) | self.dropout = float(dropout) | ||||
| self.bidirectional = bidirectional | self.bidirectional = bidirectional | ||||
| @@ -14,8 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """pooling""" | """pooling""" | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Rel | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from ... import context | from ... import context | ||||
| from ..cell import Cell | from ..cell import Cell | ||||
| @@ -24,35 +23,27 @@ class _PoolNd(Cell): | |||||
| """N-D AvgPool""" | """N-D AvgPool""" | ||||
| def __init__(self, kernel_size, stride, pad_mode): | def __init__(self, kernel_size, stride, pad_mode): | ||||
| name = self.__class__.__name__ | |||||
| super(_PoolNd, self).__init__() | super(_PoolNd, self).__init__() | ||||
| validator.check_type('kernel_size', kernel_size, [int, tuple]) | |||||
| validator.check_type('stride', stride, [int, tuple]) | |||||
| self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) | |||||
| if isinstance(kernel_size, int): | |||||
| validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) | |||||
| else: | |||||
| if (len(kernel_size) != 2 or | |||||
| (not isinstance(kernel_size[0], int)) or | |||||
| (not isinstance(kernel_size[1], int)) or | |||||
| kernel_size[0] <= 0 or | |||||
| kernel_size[1] <= 0): | |||||
| raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or' | |||||
| f'a tuple of two positive int numbers, but got {kernel_size}') | |||||
| self.kernel_size = kernel_size | |||||
| if isinstance(stride, int): | |||||
| validator.check_integer("stride", stride, 1, Rel.GE) | |||||
| else: | |||||
| if (len(stride) != 2 or | |||||
| (not isinstance(stride[0], int)) or | |||||
| (not isinstance(stride[1], int)) or | |||||
| stride[0] <= 0 or | |||||
| stride[1] <= 0): | |||||
| raise ValueError(f'The stride passed to cell {name} should be an positive int number or' | |||||
| f'a tuple of two positive int numbers, but got {stride}') | |||||
| self.stride = stride | |||||
| self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name) | |||||
| def _check_int_or_tuple(arg_name, arg_value): | |||||
| validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name) | |||||
| error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \ | |||||
| f'a tuple of two positive int numbers, but got {arg_value}' | |||||
| if isinstance(arg_value, int): | |||||
| if arg_value <= 0: | |||||
| raise ValueError(error_msg) | |||||
| elif len(arg_value) == 2: | |||||
| for item in arg_value: | |||||
| if isinstance(item, int) and item > 0: | |||||
| continue | |||||
| raise ValueError(error_msg) | |||||
| else: | |||||
| raise ValueError(error_msg) | |||||
| return arg_value | |||||
| self.kernel_size = _check_int_or_tuple('kernel_size', kernel_size) | |||||
| self.stride = _check_int_or_tuple('stride', stride) | |||||
| def construct(self, *inputs): | def construct(self, *inputs): | ||||
| pass | pass | ||||
| @@ -15,7 +15,7 @@ | |||||
| """Fbeta.""" | """Fbeta.""" | ||||
| import sys | import sys | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from .metric import Metric | from .metric import Metric | ||||
| @@ -104,7 +104,7 @@ class Fbeta(Metric): | |||||
| Returns: | Returns: | ||||
| Float, computed result. | Float, computed result. | ||||
| """ | """ | ||||
| validator.check_type("average", average, [bool]) | |||||
| validator.check_value_type("average", average, [bool], self.__class__.__name__) | |||||
| if self._class_num == 0: | if self._class_num == 0: | ||||
| raise RuntimeError('Input number of samples can not be 0.') | raise RuntimeError('Input number of samples can not be 0.') | ||||
| @@ -17,7 +17,7 @@ import sys | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from .evaluation import EvaluationBase | from .evaluation import EvaluationBase | ||||
| @@ -136,7 +136,7 @@ class Precision(EvaluationBase): | |||||
| if self._class_num == 0: | if self._class_num == 0: | ||||
| raise RuntimeError('Input number of samples can not be 0.') | raise RuntimeError('Input number of samples can not be 0.') | ||||
| validator.check_type("average", average, [bool]) | |||||
| validator.check_value_type("average", average, [bool], self.__class__.__name__) | |||||
| result = self._true_positives / (self._positives + self.eps) | result = self._true_positives / (self._positives + self.eps) | ||||
| if average: | if average: | ||||
| @@ -17,7 +17,7 @@ import sys | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from .evaluation import EvaluationBase | from .evaluation import EvaluationBase | ||||
| @@ -136,7 +136,7 @@ class Recall(EvaluationBase): | |||||
| if self._class_num == 0: | if self._class_num == 0: | ||||
| raise RuntimeError('Input number of samples can not be 0.') | raise RuntimeError('Input number of samples can not be 0.') | ||||
| validator.check_type("average", average, [bool]) | |||||
| validator.check_value_type("average", average, [bool], self.__class__.__name__) | |||||
| result = self._true_positives / (self._actual_positives + self.eps) | result = self._true_positives / (self._actual_positives + self.eps) | ||||
| if average: | if average: | ||||
| @@ -22,7 +22,7 @@ from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
| @@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad | |||||
| return next_v | return next_v | ||||
| def _check_param_value(beta1, beta2, eps, weight_decay): | |||||
| def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): | |||||
| """Check the type of inputs.""" | """Check the type of inputs.""" | ||||
| validator.check_type("beta1", beta1, [float]) | |||||
| validator.check_type("beta2", beta2, [float]) | |||||
| validator.check_type("eps", eps, [float]) | |||||
| validator.check_type("weight_dacay", weight_decay, [float]) | |||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) | |||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) | |||||
| validator.check_value_type("beta1", beta1, [float], prim_name) | |||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||||
| validator.check_value_type("eps", eps, [float], prim_name) | |||||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | |||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | ||||
| @@ -168,11 +168,11 @@ class Adam(Optimizer): | |||||
| use_nesterov=False, weight_decay=0.0, loss_scale=1.0, | use_nesterov=False, weight_decay=0.0, loss_scale=1.0, | ||||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | ||||
| super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | ||||
| _check_param_value(beta1, beta2, eps, weight_decay) | |||||
| validator.check_type("use_locking", use_locking, [bool]) | |||||
| validator.check_type("use_nesterov", use_nesterov, [bool]) | |||||
| validator.check_type("loss_scale", loss_scale, [float]) | |||||
| validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT) | |||||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||||
| validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) | |||||
| validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) | |||||
| validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) | |||||
| validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||||
| self.beta1 = Tensor(beta1, mstype.float32) | self.beta1 = Tensor(beta1, mstype.float32) | ||||
| self.beta2 = Tensor(beta2, mstype.float32) | self.beta2 = Tensor(beta2, mstype.float32) | ||||
| @@ -241,7 +241,7 @@ class AdamWeightDecay(Optimizer): | |||||
| """ | """ | ||||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): | def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): | ||||
| super(AdamWeightDecay, self).__init__(learning_rate, params) | super(AdamWeightDecay, self).__init__(learning_rate, params) | ||||
| _check_param_value(beta1, beta2, eps, weight_decay) | |||||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||||
| self.lr = Tensor(np.array([learning_rate]).astype(np.float32)) | self.lr = Tensor(np.array([learning_rate]).astype(np.float32)) | ||||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | ||||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | ||||
| @@ -304,7 +304,7 @@ class AdamWeightDecayDynamicLR(Optimizer): | |||||
| eps=1e-6, | eps=1e-6, | ||||
| weight_decay=0.0): | weight_decay=0.0): | ||||
| super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) | super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) | ||||
| _check_param_value(beta1, beta2, eps, weight_decay) | |||||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||||
| # turn them to scalar when me support scalar/tensor mix operations | # turn them to scalar when me support scalar/tensor mix operations | ||||
| self.global_step = Parameter(initializer(0, [1]), name="global_step") | self.global_step = Parameter(initializer(0, [1]), name="global_step") | ||||
| @@ -18,7 +18,7 @@ from mindspore.common.initializer import initializer | |||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common import Tensor | from mindspore.common import Tensor | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| from .optimizer import Optimizer, apply_decay, grad_scale | from .optimizer import Optimizer, apply_decay, grad_scale | ||||
| @@ -30,29 +30,30 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig | |||||
| success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) | success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) | ||||
| return success | return success | ||||
| def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0): | |||||
| validator.check_type("initial_accum", initial_accum, [float]) | |||||
| validator.check("initial_accum", initial_accum, "", 0.0, Rel.GE) | |||||
| def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0, | |||||
| prim_name=None): | |||||
| validator.check_value_type("initial_accum", initial_accum, [float], prim_name) | |||||
| validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name) | |||||
| validator.check_type("learning_rate", learning_rate, [float]) | |||||
| validator.check("learning_rate", learning_rate, "", 0.0, Rel.GT) | |||||
| validator.check_value_type("learning_rate", learning_rate, [float], prim_name) | |||||
| validator.check_number("learning_rate", learning_rate, 0.0, Rel.GT, prim_name) | |||||
| validator.check_type("lr_power", lr_power, [float]) | |||||
| validator.check("lr_power", lr_power, "", 0.0, Rel.LE) | |||||
| validator.check_value_type("lr_power", lr_power, [float], prim_name) | |||||
| validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name) | |||||
| validator.check_type("l1", l1, [float]) | |||||
| validator.check("l1", l1, "", 0.0, Rel.GE) | |||||
| validator.check_value_type("l1", l1, [float], prim_name) | |||||
| validator.check_number("l1", l1, 0.0, Rel.GE, prim_name) | |||||
| validator.check_type("l2", l2, [float]) | |||||
| validator.check("l2", l2, "", 0.0, Rel.GE) | |||||
| validator.check_value_type("l2", l2, [float], prim_name) | |||||
| validator.check_number("l2", l2, 0.0, Rel.GE, prim_name) | |||||
| validator.check_type("use_locking", use_locking, [bool]) | |||||
| validator.check_value_type("use_locking", use_locking, [bool], prim_name) | |||||
| validator.check_type("loss_scale", loss_scale, [float]) | |||||
| validator.check("loss_scale", loss_scale, "", 1.0, Rel.GE) | |||||
| validator.check_value_type("loss_scale", loss_scale, [float], prim_name) | |||||
| validator.check_number("loss_scale", loss_scale, 1.0, Rel.GE, prim_name) | |||||
| validator.check_type("weight_decay", weight_decay, [float]) | |||||
| validator.check("weight_decay", weight_decay, "", 0.0, Rel.GE) | |||||
| validator.check_value_type("weight_decay", weight_decay, [float], prim_name) | |||||
| validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name) | |||||
| class FTRL(Optimizer): | class FTRL(Optimizer): | ||||
| @@ -94,7 +95,8 @@ class FTRL(Optimizer): | |||||
| use_locking=False, loss_scale=1.0, weight_decay=0.0): | use_locking=False, loss_scale=1.0, weight_decay=0.0): | ||||
| super(FTRL, self).__init__(learning_rate, params) | super(FTRL, self).__init__(learning_rate, params) | ||||
| _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay) | |||||
| _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay, | |||||
| self.cls_name) | |||||
| self.moments = self.parameters.clone(prefix="moments", init=initial_accum) | self.moments = self.parameters.clone(prefix="moments", init=initial_accum) | ||||
| self.linear = self.parameters.clone(prefix="linear", init='zeros') | self.linear = self.parameters.clone(prefix="linear", init='zeros') | ||||
| self.l1 = l1 | self.l1 = l1 | ||||
| @@ -21,7 +21,7 @@ from mindspore.ops import composite as C | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
| from .. import layer | from .. import layer | ||||
| @@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para | |||||
| def _check_param_value(decay_steps, warmup_steps, start_learning_rate, | def _check_param_value(decay_steps, warmup_steps, start_learning_rate, | ||||
| end_learning_rate, power, beta1, beta2, eps, weight_decay): | |||||
| end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): | |||||
| """Check the type of inputs.""" | """Check the type of inputs.""" | ||||
| validator.check_type("decay_steps", decay_steps, [int]) | |||||
| validator.check_type("warmup_steps", warmup_steps, [int]) | |||||
| validator.check_type("start_learning_rate", start_learning_rate, [float]) | |||||
| validator.check_type("end_learning_rate", end_learning_rate, [float]) | |||||
| validator.check_type("power", power, [float]) | |||||
| validator.check_type("beta1", beta1, [float]) | |||||
| validator.check_type("beta2", beta2, [float]) | |||||
| validator.check_type("eps", eps, [float]) | |||||
| validator.check_type("weight_dacay", weight_decay, [float]) | |||||
| validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT) | |||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) | |||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) | |||||
| validator.check_value_type("decay_steps", decay_steps, [int], prim_name) | |||||
| validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name) | |||||
| validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name) | |||||
| validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) | |||||
| validator.check_value_type("power", power, [float], prim_name) | |||||
| validator.check_value_type("beta1", beta1, [float], prim_name) | |||||
| validator.check_value_type("beta2", beta2, [float], prim_name) | |||||
| validator.check_value_type("eps", eps, [float], prim_name) | |||||
| validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) | |||||
| validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) | |||||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) | |||||
| class Lamb(Optimizer): | class Lamb(Optimizer): | ||||
| @@ -182,7 +182,7 @@ class Lamb(Optimizer): | |||||
| super(Lamb, self).__init__(start_learning_rate, params) | super(Lamb, self).__init__(start_learning_rate, params) | ||||
| _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, | _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, | ||||
| power, beta1, beta2, eps, weight_decay) | |||||
| power, beta1, beta2, eps, weight_decay, self.cls_name) | |||||
| # turn them to scalar when me support scalar/tensor mix operations | # turn them to scalar when me support scalar/tensor mix operations | ||||
| self.global_step = Parameter(initializer(0, [1]), name="global_step") | self.global_step = Parameter(initializer(0, [1]), name="global_step") | ||||
| @@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P | |||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore.common.parameter import Parameter, ParameterTuple | from mindspore.common.parameter import Parameter, ParameterTuple | ||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| @@ -63,7 +63,7 @@ class Optimizer(Cell): | |||||
| self.gather = None | self.gather = None | ||||
| self.assignadd = None | self.assignadd = None | ||||
| self.global_step = None | self.global_step = None | ||||
| validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT) | |||||
| validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||||
| else: | else: | ||||
| self.dynamic_lr = True | self.dynamic_lr = True | ||||
| self.gather = P.GatherV2() | self.gather = P.GatherV2() | ||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """rmsprop""" | """rmsprop""" | ||||
| from mindspore.ops import functional as F, composite as C, operations as P | from mindspore.ops import functional as F, composite as C, operations as P | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
| rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | ||||
| @@ -144,8 +144,8 @@ class RMSProp(Optimizer): | |||||
| self.decay = decay | self.decay = decay | ||||
| self.epsilon = epsilon | self.epsilon = epsilon | ||||
| validator.check_type("use_locking", use_locking, [bool]) | |||||
| validator.check_type("centered", centered, [bool]) | |||||
| validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) | |||||
| validator.check_value_type("centered", centered, [bool], self.cls_name) | |||||
| self.centered = centered | self.centered = centered | ||||
| if centered: | if centered: | ||||
| self.opt = P.ApplyCenteredRMSProp(use_locking) | self.opt = P.ApplyCenteredRMSProp(use_locking) | ||||
| @@ -15,7 +15,7 @@ | |||||
| """sgd""" | """sgd""" | ||||
| from mindspore.ops import functional as F, composite as C, operations as P | from mindspore.ops import functional as F, composite as C, operations as P | ||||
| from mindspore.common.parameter import Parameter | from mindspore.common.parameter import Parameter | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
| sgd_opt = C.MultitypeFuncGraph("sgd_opt") | sgd_opt = C.MultitypeFuncGraph("sgd_opt") | ||||
| @@ -100,7 +100,7 @@ class SGD(Optimizer): | |||||
| raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening)) | raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening)) | ||||
| self.dampening = dampening | self.dampening = dampening | ||||
| validator.check_type("nesterov", nesterov, [bool]) | |||||
| validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) | |||||
| self.nesterov = nesterov | self.nesterov = nesterov | ||||
| self.opt = P.SGD(dampening, weight_decay, nesterov) | self.opt = P.SGD(dampening, weight_decay, nesterov) | ||||
| @@ -19,7 +19,7 @@ import os | |||||
| import json | import json | ||||
| import inspect | import inspect | ||||
| from mindspore._c_expression import Oplib | from mindspore._c_expression import Oplib | ||||
| from mindspore._checkparam import ParamValidator as validator | |||||
| from mindspore._checkparam import Validator as validator | |||||
| # path of built-in op info register. | # path of built-in op info register. | ||||
| BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl" | BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl" | ||||
| @@ -43,7 +43,7 @@ def op_info_register(op_info): | |||||
| op_info_real = json.dumps(op_info) | op_info_real = json.dumps(op_info) | ||||
| else: | else: | ||||
| op_info_real = op_info | op_info_real = op_info | ||||
| validator.check_type("op_info", op_info_real, [str]) | |||||
| validator.check_value_type("op_info", op_info_real, [str], None) | |||||
| op_lib = Oplib() | op_lib = Oplib() | ||||
| file_path = os.path.realpath(inspect.getfile(func)) | file_path = os.path.realpath(inspect.getfile(func)) | ||||
| # keep the path custom ops implementation. | # keep the path custom ops implementation. | ||||
| @@ -16,7 +16,7 @@ | |||||
| from easydict import EasyDict as edict | from easydict import EasyDict as edict | ||||
| from .. import nn | from .. import nn | ||||
| from .._checkparam import ParamValidator as validator | |||||
| from .._checkparam import Validator as validator | |||||
| from .._checkparam import Rel | from .._checkparam import Rel | ||||
| from ..common import dtype as mstype | from ..common import dtype as mstype | ||||
| from ..nn.wrap.cell_wrapper import _VirtualDatasetCell | from ..nn.wrap.cell_wrapper import _VirtualDatasetCell | ||||
| @@ -73,14 +73,14 @@ def _check_kwargs(key_words): | |||||
| raise ValueError(f"Unsupported arg '{arg}'") | raise ValueError(f"Unsupported arg '{arg}'") | ||||
| if 'cast_model_type' in key_words: | if 'cast_model_type' in key_words: | ||||
| validator.check('cast_model_type', key_words['cast_model_type'], | |||||
| [mstype.float16, mstype.float32], Rel.IN) | |||||
| validator.check_type_name('cast_model_type', key_words['cast_model_type'], | |||||
| [mstype.float16, mstype.float32], None) | |||||
| if 'keep_batchnorm_fp32' in key_words: | if 'keep_batchnorm_fp32' in key_words: | ||||
| validator.check_isinstance('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool) | |||||
| validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None) | |||||
| if 'loss_scale_manager' in key_words: | if 'loss_scale_manager' in key_words: | ||||
| loss_scale_manager = key_words['loss_scale_manager'] | loss_scale_manager = key_words['loss_scale_manager'] | ||||
| if loss_scale_manager: | if loss_scale_manager: | ||||
| validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager) | |||||
| validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None) | |||||
| def _add_loss_network(network, loss_fn, cast_model_type): | def _add_loss_network(network, loss_fn, cast_model_type): | ||||
| @@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type): | |||||
| label = _mp_cast_helper(mstype.float32, label) | label = _mp_cast_helper(mstype.float32, label) | ||||
| return self._loss_fn(F.cast(out, mstype.float32), label) | return self._loss_fn(F.cast(out, mstype.float32), label) | ||||
| validator.check_isinstance('loss_fn', loss_fn, nn.Cell) | |||||
| validator.check_value_type('loss_fn', loss_fn, nn.Cell, None) | |||||
| if cast_model_type == mstype.float16: | if cast_model_type == mstype.float16: | ||||
| network = WithLossCell(network, loss_fn) | network = WithLossCell(network, loss_fn) | ||||
| else: | else: | ||||
| @@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): | |||||
| loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else | loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else | ||||
| scale the loss by LossScaleManager. If set, overwrite the level setting. | scale the loss by LossScaleManager. If set, overwrite the level setting. | ||||
| """ | """ | ||||
| validator.check_isinstance('network', network, nn.Cell) | |||||
| validator.check_isinstance('optimizer', optimizer, nn.Optimizer) | |||||
| validator.check('level', level, "", ['O0', 'O2'], Rel.IN) | |||||
| validator.check_value_type('network', network, nn.Cell, None) | |||||
| validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) | |||||
| validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None) | |||||
| _check_kwargs(kwargs) | _check_kwargs(kwargs) | ||||
| config = dict(_config_level[level], **kwargs) | config = dict(_config_level[level], **kwargs) | ||||
| config = edict(config) | config = edict(config) | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """Loss scale manager abstract class.""" | """Loss scale manager abstract class.""" | ||||
| from .._checkparam import ParamValidator as validator | |||||
| from .._checkparam import Validator as validator | |||||
| from .._checkparam import Rel | from .._checkparam import Rel | ||||
| from .. import nn | from .. import nn | ||||
| @@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager): | |||||
| if init_loss_scale < 1.0: | if init_loss_scale < 1.0: | ||||
| raise ValueError("Loss scale value should be > 1") | raise ValueError("Loss scale value should be > 1") | ||||
| self.loss_scale = init_loss_scale | self.loss_scale = init_loss_scale | ||||
| validator.check_integer("scale_window", scale_window, 0, Rel.GT) | |||||
| validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__) | |||||
| self.scale_window = scale_window | self.scale_window = scale_window | ||||
| if scale_factor <= 0: | if scale_factor <= 0: | ||||
| raise ValueError("Scale factor should be > 1") | raise ValueError("Scale factor should be > 1") | ||||
| @@ -32,7 +32,7 @@ power = 0.5 | |||||
| class TestInputs: | class TestInputs: | ||||
| def test_milestone1(self): | def test_milestone1(self): | ||||
| milestone1 = 1 | milestone1 = 1 | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| dr.piecewise_constant_lr(milestone1, learning_rates) | dr.piecewise_constant_lr(milestone1, learning_rates) | ||||
| def test_milestone2(self): | def test_milestone2(self): | ||||
| @@ -46,12 +46,12 @@ class TestInputs: | |||||
| def test_learning_rates1(self): | def test_learning_rates1(self): | ||||
| lr = True | lr = True | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| dr.piecewise_constant_lr(milestone, lr) | dr.piecewise_constant_lr(milestone, lr) | ||||
| def test_learning_rates2(self): | def test_learning_rates2(self): | ||||
| lr = [1, 2, 1] | lr = [1, 2, 1] | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| dr.piecewise_constant_lr(milestone, lr) | dr.piecewise_constant_lr(milestone, lr) | ||||
| def test_learning_rate_type(self): | def test_learning_rate_type(self): | ||||
| @@ -158,7 +158,7 @@ class TestInputs: | |||||
| def test_is_stair(self): | def test_is_stair(self): | ||||
| is_stair = 1 | is_stair = 1 | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) | dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) | ||||
| def test_min_lr_type(self): | def test_min_lr_type(self): | ||||
| @@ -183,12 +183,12 @@ class TestInputs: | |||||
| def test_power(self): | def test_power(self): | ||||
| power1 = True | power1 = True | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1) | dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1) | ||||
| def test_update_decay_epoch(self): | def test_update_decay_epoch(self): | ||||
| update_decay_epoch = 1 | update_decay_epoch = 1 | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, | dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, | ||||
| power, update_decay_epoch) | power, update_decay_epoch) | ||||
| @@ -52,7 +52,7 @@ def test_psnr_max_val_negative(): | |||||
| def test_psnr_max_val_bool(): | def test_psnr_max_val_bool(): | ||||
| max_val = True | max_val = True | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| net = PSNRNet(max_val) | net = PSNRNet(max_val) | ||||
| def test_psnr_max_val_zero(): | def test_psnr_max_val_zero(): | ||||
| @@ -51,7 +51,7 @@ def test_ssim_max_val_negative(): | |||||
| def test_ssim_max_val_bool(): | def test_ssim_max_val_bool(): | ||||
| max_val = True | max_val = True | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| net = SSIMNet(max_val) | net = SSIMNet(max_val) | ||||
| def test_ssim_max_val_zero(): | def test_ssim_max_val_zero(): | ||||
| @@ -92,4 +92,4 @@ def test_ssim_k1_k2_wrong_value(): | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net = SSIMNet(k2=0.0) | net = SSIMNet(k2=0.0) | ||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net = SSIMNet(k2=-1.0) | |||||
| net = SSIMNet(k2=-1.0) | |||||
| @@ -577,14 +577,14 @@ test_cases_for_verify_exception = [ | |||||
| ('MaxPool2d_ValueError_2', { | ('MaxPool2d_ValueError_2', { | ||||
| 'block': ( | 'block': ( | ||||
| lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"), | lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"), | ||||
| {'exception': ValueError}, | |||||
| {'exception': TypeError}, | |||||
| ), | ), | ||||
| 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], | 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], | ||||
| }), | }), | ||||
| ('MaxPool2d_ValueError_3', { | ('MaxPool2d_ValueError_3', { | ||||
| 'block': ( | 'block': ( | ||||
| lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"), | lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"), | ||||
| {'exception': ValueError}, | |||||
| {'exception': TypeError}, | |||||
| ), | ), | ||||
| 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], | 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], | ||||
| }), | }), | ||||
| @@ -38,7 +38,7 @@ def test_avgpool2d_error_input(): | |||||
| """ test_avgpool2d_error_input """ | """ test_avgpool2d_error_input """ | ||||
| kernel_size = 5 | kernel_size = 5 | ||||
| stride = 2.3 | stride = 2.3 | ||||
| with pytest.raises(ValueError): | |||||
| with pytest.raises(TypeError): | |||||
| nn.AvgPool2d(kernel_size, stride) | nn.AvgPool2d(kernel_size, stride) | ||||