@@ -17,7 +17,7 @@ import re
from enum import Enum
from functools import reduce
from itertools import repeat
from collections import Iterable
from collections.abc import Iterable
import numpy as np
from mindspore import log as logger
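For context on the one-line import change above: the ABCs were moved out of `collections` in Python 3.3 and the old aliases were removed in Python 3.10, so only the `collections.abc` spelling keeps working. A quick standalone check (not part of the patch itself):

```python
# On Python 3.10+ the old spelling fails outright:
#   from collections import Iterable  ->  ImportError
from collections.abc import Iterable

print(isinstance((1, 2, 3), Iterable))  # True
print(isinstance(42, Iterable))         # False
```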
@@ -98,7 +98,7 @@ class Validator:
    """validator for checking input parameters"""
    @staticmethod
    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None):
    def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError):
        """
        Method for judging relation between two int values or list/tuple made up of ints.
@@ -108,8 +108,8 @@ class Validator:
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}')
            msg_prefix = f'For {prim_name} the' if prim_name else "The"
            raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
            msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
            raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.')
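A minimal sketch of how the widened `check` signature might be called; the primitive name 'Softmax' and the argument values are illustrative, not taken from the patch:

```python
from mindspore._checkparam import Validator, Rel

# Relation holds: passes silently.
Validator.check('axis', 1, 'rank', 3, Rel.LE, 'Softmax')

# With excp_cls the caller chooses the exception type raised on failure.
# Validator.check('axis', 4, 'rank', 3, Rel.LE, 'Softmax', excp_cls=TypeError)
# -> TypeError: For 'Softmax' the `axis` should be ... rank: 3, but got 4.
```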
    @staticmethod
    def check_integer(arg_name, arg_value, value, rel, prim_name):
@@ -118,8 +118,17 @@ class Validator:
        type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool)
        if type_mismatch or not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},'
                             f' but got {arg_value}.')
            msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
            raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.')
        return arg_value
    @staticmethod
    def check_number(arg_name, arg_value, value, rel, prim_name):
| """Integer value judgment.""" | |||
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, value):
            rel_str = Rel.get_strs(rel).format(value)
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.')
        return arg_value
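The new `check_number` mirrors `check_integer` but drops the int type gate, so floats pass through; a small usage sketch with illustrative values (this is how the SSIM/PSNR cells below call it):

```python
from mindspore._checkparam import Validator, Rel

max_val = Validator.check_number('max_val', 1.0, 0.0, Rel.GT, 'PSNR')  # returns 1.0
# Validator.check_number('max_val', 0.0, 0.0, Rel.GT, 'PSNR')
# -> ValueError: For 'PSNR' the `max_val` must ..., but got 0.0.
```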
    @staticmethod
@@ -133,9 +142,46 @@ class Validator:
                         f' but got {arg_value}.')
        return arg_value
    @staticmethod
    def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name):
        """Method for checking whether a numeric value is in some range."""
        rel_fn = Rel.get_fns(rel)
        if not rel_fn(arg_value, lower_limit, upper_limit):
            rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit)
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.')
        return arg_value
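Usage matches the SSIM changes later in the diff, e.g. `k1` must lie strictly inside (0, 1); a hedged sketch:

```python
from mindspore._checkparam import Validator, Rel

k1 = Validator.check_number_range('k1', 0.01, 0.0, 1.0, Rel.INC_NEITHER, 'SSIM')  # returns 0.01
# Validator.check_number_range('k1', 1.0, 0.0, 1.0, Rel.INC_NEITHER, 'SSIM')
# -> ValueError: For 'SSIM' the `k1` should be in range ..., but got 1.0.
```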
    @staticmethod
    def check_string(arg_name, arg_value, valid_values, prim_name):
        """Checks whether a string is in some value list"""
        if isinstance(arg_value, str) and arg_value in valid_values:
            return arg_value
        if len(valid_values) == 1:
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},'
                             f' but got {arg_value}.')
        raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},'
                         f' but got {arg_value}.')
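As exercised by the `Pad` change further down; note the dedicated message when only one value is allowed. A sketch:

```python
from mindspore._checkparam import Validator

mode = Validator.check_string('mode', 'REFLECT', ['CONSTANT', 'REFLECT', 'SYMMETRIC'], 'Pad')
# Validator.check_string('mode', 'EDGE', ['CONSTANT', 'REFLECT', 'SYMMETRIC'], 'Pad')
# -> ValueError: For 'Pad' the `mode` should be str and must be one of [...], but got EDGE.
```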
    @staticmethod
    def check_pad_value_by_mode(pad_mode, padding, prim_name):
        """Validates value of padding according to pad_mode"""
        if pad_mode != 'pad' and padding != 0:
            raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.")
        return padding
    @staticmethod
    def check_float_positive(arg_name, arg_value, prim_name):
        """Float type judgment."""
        msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The"
        if isinstance(arg_value, float):
            if arg_value > 0:
                return arg_value
            raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.")
        raise TypeError(f"{msg_prefix} `{arg_name}` must be float.")
    @staticmethod
    def check_subclass(arg_name, type_, template_type, prim_name):
        """Check whether some type is sublcass of another type"""
| """Checks whether some type is sublcass of another type""" | |||
        if not isinstance(template_type, Iterable):
            template_type = (template_type,)
        if not any([mstype.issubclass_(type_, x) for x in template_type]):
@@ -143,16 +189,44 @@ class Validator:
            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass'
                            f' of {",".join((str(x) for x in template_type))}, but got {type_str}.')
    @staticmethod
    def check_const_input(arg_name, arg_value, prim_name):
        """Check valid value."""
        if arg_value is None:
            raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.')
    @staticmethod
    def check_scalar_type_same(args, valid_values, prim_name):
| """check whether the types of inputs are the same.""" | |||
        def _check_tensor_type(arg):
            arg_key, arg_val = arg
            elem_type = arg_val
            if not elem_type in valid_values:
                raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {elem_type}.')
            return (arg_key, elem_type)
        def _check_types_same(arg1, arg2):
            arg1_name, arg1_type = arg1
            arg2_name, arg2_type = arg2
            if arg1_type != arg2_type:
                raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
                                f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
            return arg1
        elem_types = map(_check_tensor_type, args.items())
        reduce(_check_types_same, elem_types)
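The reduction walks the `args` dict pairwise, so the first mismatch names both offending inputs. The dtype values below are assumptions about a typical call site, not taken from the patch:

```python
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator

# All scalar dtypes agree and sit in the valid set: passes.
Validator.check_scalar_type_same({'x': mstype.float32, 'y': mstype.float32},
                                 (mstype.float16, mstype.float32), 'Add')
# {'x': mstype.float32, 'y': mstype.int32} would raise TypeError naming `x` and `y`.
```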
    @staticmethod
    def check_tensor_type_same(args, valid_values, prim_name):
        """check whether the element types of input tensors are the same."""
        """Checks whether the element types of input tensors are the same."""
        def _check_tensor_type(arg):
            arg_key, arg_val = arg
            Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name)
            elem_type = arg_val.element_type()
            if not elem_type in valid_values:
                raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},'
                                f' but `{arg_key}` is {elem_type}.')
                                f' but element type of `{arg_key}` is {elem_type}.')
            return (arg_key, elem_type)
        def _check_types_same(arg1, arg2):
@@ -168,8 +242,13 @@ class Validator:
    @staticmethod
    def check_scalar_or_tensor_type_same(args, valid_values, prim_name):
        """check whether the types of inputs are the same. if the input args are tensors, check their element types"""
    def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
        """
        Checks whether the types of inputs are the same. If the input args are tensors, checks their element types.
        If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised.
        """
        def _check_argument_type(arg):
            arg_key, arg_val = arg
            if isinstance(arg_val, type(mstype.tensor)):
@@ -188,6 +267,9 @@ class Validator:
                arg2_type = arg2_type.element_type()
            elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))):
                pass
            elif allow_mix:
                arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
                arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
            else:
                excp_flag = True
@@ -199,13 +281,14 @@ class Validator:
    @staticmethod
    def check_value_type(arg_name, arg_value, valid_types, prim_name):
        """Check whether a values is instance of some types."""
| """Checks whether a value is instance of some types.""" | |||
        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
        def raise_error_msg():
            """func for raising error message when check failed"""
            type_names = [t.__name__ for t in valid_types]
            num_types = len(valid_types)
            raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be '
                            f'{"one of " if num_types > 1 else ""}'
            msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
            raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
                            f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
        # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and
@@ -216,6 +299,23 @@ class Validator:
            return arg_value
        raise_error_msg()
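The bool-versus-int special case called out in the Notice above is also why several tests near the end of this diff switch from ValueError to TypeError; it can be seen directly:

```python
from mindspore._checkparam import Validator

Validator.check_value_type('is_stair', False, [bool], None)  # ok: returns False
# Validator.check_value_type('is_stair', 1, [bool], None)    # TypeError: got int
# Validator.check_value_type('x', True, [int], None)         # TypeError: bool is excluded from int
```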
    @staticmethod
    def check_type_name(arg_name, arg_type, valid_types, prim_name):
| """Checks whether a type in some specified types""" | |||
        valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
        def get_typename(t):
            return t.__name__ if hasattr(t, '__name__') else str(t)
        if arg_type in valid_types:
            return arg_type
        type_names = [get_typename(t) for t in valid_types]
        msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
        if len(valid_types) == 1:
            raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},'
                             f' but got {get_typename(arg_type)}.')
        raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},'
                         f' but got {get_typename(arg_type)}.')
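This is what the amp changes near the end of the diff switch to for `cast_model_type`; a hedged sketch whose dtype arguments mirror that call site:

```python
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator

t = Validator.check_type_name('cast_model_type', mstype.float16,
                              [mstype.float16, mstype.float32], None)  # returns mstype.float16
# Passing mstype.int32 instead would raise ValueError listing both valid type names.
```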
class ParamValidator:
    """Parameter validator. NOTICE: this class will be replaced by `class Validator`"""
@@ -103,6 +103,10 @@ class Cell:
    def parameter_layout_dict(self):
        return self._parameter_layout_dict
    @property
    def cls_name(self):
        return self.__class__.__name__
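The new property just exposes the concrete class name, so every `Cell` subclass below can hand it to the validator as `prim_name`. A toy illustration (the subclass is hypothetical):

```python
from mindspore.nn import Cell

class MyLayer(Cell):  # hypothetical subclass, for illustration only
    def __init__(self):
        super(MyLayer, self).__init__()
        print(self.cls_name)  # 'MyLayer' -- ready to pass as prim_name

MyLayer()
```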
    @parameter_layout_dict.setter
    def parameter_layout_dict(self, value):
        if not isinstance(value, dict):
@@ -15,7 +15,7 @@
"""dynamic learning rate"""
import math
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
@@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates):
        >>> lr = piecewise_constant_lr(milestone, learning_rates)
        [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
    """
    validator.check_type('milestone', milestone, (tuple, list))
    validator.check_type('learning_rates', learning_rates, (tuple, list))
    validator.check_value_type('milestone', milestone, (tuple, list), None)
    validator.check_value_type('learning_rates', learning_rates, (tuple, list), None)
    if len(milestone) != len(learning_rates):
        raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.')
    lr = []
    last_item = 0
    for i, item in enumerate(milestone):
        validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT)
        validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float])
        validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None)
        validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None)
        if item < last_item:
            raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]')
        lr += [learning_rates[i]] * (item - last_item)
@@ -62,12 +62,12 @@ def piecewise_constant_lr(milestone, learning_rates):
def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
    validator.check_float_positive('learning_rate', learning_rate)
    validator.check_float_positive('decay_rate', decay_rate)
    validator.check_type('is_stair', is_stair, [bool])
    validator.check_integer('total_step', total_step, 0, Rel.GT, None)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
    validator.check_float_positive('learning_rate', learning_rate, None)
    validator.check_float_positive('decay_rate', decay_rate, None)
    validator.check_value_type('is_stair', is_stair, [bool], None)
def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
@@ -228,11 +228,11 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
        >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
        [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
    """
    validator.check_float_positive('min_lr', min_lr)
    validator.check_float_positive('max_lr', max_lr)
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
    validator.check_float_positive('min_lr', min_lr, None)
    validator.check_float_positive('max_lr', max_lr, None)
    validator.check_integer('total_step', total_step, 0, Rel.GT, None)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
    delta = 0.5 * (max_lr - min_lr)
    lr = []
@@ -279,13 +279,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e
        >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
        [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
    """
    validator.check_float_positive('learning_rate', learning_rate)
    validator.check_float_positive('end_learning_rate', end_learning_rate)
    validator.check_integer('total_step', total_step, 0, Rel.GT)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT)
    validator.check_type('power', power, [float])
    validator.check_type('update_decay_epoch', update_decay_epoch, [bool])
    validator.check_float_positive('learning_rate', learning_rate, None)
    validator.check_float_positive('end_learning_rate', end_learning_rate, None)
    validator.check_integer('total_step', total_step, 0, Rel.GT, None)
    validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None)
    validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None)
    validator.check_value_type('power', power, [float], None)
    validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None)
    function = lambda x, y: (x, min(x, y))
    if update_decay_epoch:
@@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter
from mindspore._extends import cell_attr_register
from ..cell import Cell
from .activation import get_activation
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator as validator
class Dropout(Cell):
@@ -73,7 +73,7 @@ class Dropout(Cell):
        super(Dropout, self).__init__()
        if keep_prob <= 0 or keep_prob > 1:
            raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob))
        validator.check_subclass("dtype", dtype, mstype.number_type)
        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.keep_prob = Tensor(keep_prob)
        self.seed0 = seed0
        self.seed1 = seed1
@@ -421,7 +421,7 @@ class Pad(Cell):
        super(Pad, self).__init__()
        self.mode = mode
        self.paddings = paddings
        validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"])
        validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name)
        if not isinstance(paddings, tuple):
            raise TypeError('Paddings must be tuple type.')
        for item in paddings:
@@ -19,7 +19,7 @@ from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from ..cell import Cell
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Validator as validator
class Embedding(Cell):
@@ -59,7 +59,7 @@ class Embedding(Cell):
    """
    def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32):
        super(Embedding, self).__init__()
        validator.check_subclass("dtype", dtype, mstype.number_type)
        validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.use_one_hot = use_one_hot
@@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.primitive import constexpr
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ..cell import Cell
@@ -134,15 +134,15 @@ class SSIM(Cell):
    """
    def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
        super(SSIM, self).__init__()
        validator.check_type('max_val', max_val, [int, float])
        validator.check('max_val', max_val, '', 0.0, Rel.GT)
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
        self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE)
        self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma)
        validator.check_type('k1', k1, [float])
        self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER)
        validator.check_type('k2', k2, [float])
        self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER)
        self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name)
        self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name)
        validator.check_value_type('k1', k1, [float], self.cls_name)
        self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
        validator.check_value_type('k2', k2, [float], self.cls_name)
        self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name)
        self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size)
    def construct(self, img1, img2):
@@ -231,8 +231,8 @@ class PSNR(Cell):
    """
    def __init__(self, max_val=1.0):
        super(PSNR, self).__init__()
        validator.check_type('max_val', max_val, [int, float])
        validator.check('max_val', max_val, '', 0.0, Rel.GT)
        validator.check_value_type('max_val', max_val, [int, float], self.cls_name)
        validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name)
        self.max_val = max_val
    def construct(self, img1, img2):
@@ -17,7 +17,7 @@ from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
class LSTM(Cell):
@@ -114,7 +114,7 @@ class LSTM(Cell):
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.has_bias = has_bias
        self.batch_first = validator.check_type("batch_first", batch_first, [bool])
        self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
@@ -14,8 +14,7 @@
# ============================================================================
"""pooling"""
from mindspore.ops import operations as P
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from ... import context
from ..cell import Cell
@@ -24,35 +23,27 @@ class _PoolNd(Cell):
| """N-D AvgPool""" | |||
| def __init__(self, kernel_size, stride, pad_mode): | |||
| name = self.__class__.__name__ | |||
| super(_PoolNd, self).__init__() | |||
| validator.check_type('kernel_size', kernel_size, [int, tuple]) | |||
| validator.check_type('stride', stride, [int, tuple]) | |||
| self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) | |||
| if isinstance(kernel_size, int): | |||
| validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) | |||
| else: | |||
| if (len(kernel_size) != 2 or | |||
| (not isinstance(kernel_size[0], int)) or | |||
| (not isinstance(kernel_size[1], int)) or | |||
| kernel_size[0] <= 0 or | |||
| kernel_size[1] <= 0): | |||
| raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or' | |||
| f'a tuple of two positive int numbers, but got {kernel_size}') | |||
| self.kernel_size = kernel_size | |||
| if isinstance(stride, int): | |||
| validator.check_integer("stride", stride, 1, Rel.GE) | |||
| else: | |||
| if (len(stride) != 2 or | |||
| (not isinstance(stride[0], int)) or | |||
| (not isinstance(stride[1], int)) or | |||
| stride[0] <= 0 or | |||
| stride[1] <= 0): | |||
| raise ValueError(f'The stride passed to cell {name} should be an positive int number or' | |||
| f'a tuple of two positive int numbers, but got {stride}') | |||
| self.stride = stride | |||
| self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name) | |||
| def _check_int_or_tuple(arg_name, arg_value): | |||
| validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name) | |||
            error_msg = f'For \'{self.cls_name}\' the {arg_name} should be a positive int number or ' \
                        f'a tuple of two positive int numbers, but got {arg_value}'
            if isinstance(arg_value, int):
                if arg_value <= 0:
                    raise ValueError(error_msg)
            elif len(arg_value) == 2:
                for item in arg_value:
                    if isinstance(item, int) and item > 0:
                        continue
                    raise ValueError(error_msg)
            else:
                raise ValueError(error_msg)
            return arg_value
        self.kernel_size = _check_int_or_tuple('kernel_size', kernel_size)
        self.stride = _check_int_or_tuple('stride', stride)
    def construct(self, *inputs):
        pass
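Taken together with the test updates at the bottom of the diff, the refactored `_check_int_or_tuple` gives the pooling cells this behavior (a sketch; the failing calls are commented out):

```python
import mindspore.nn as nn

pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')  # ok: positive ints
pool = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))          # ok: 2-tuples of positive ints
# nn.MaxPool2d(kernel_size=3, stride=True)    # TypeError: bool fails check_value_type
# nn.MaxPool2d(kernel_size=3, stride=2.3)     # TypeError: float is neither int nor tuple
# nn.MaxPool2d(kernel_size=(3, 0), stride=1)  # ValueError: non-positive entry
```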
@@ -15,7 +15,7 @@
"""Fbeta."""
import sys
import numpy as np
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .metric import Metric
@@ -104,7 +104,7 @@ class Fbeta(Metric):
        Returns:
            Float, computed result.
        """
        validator.check_type("average", average, [bool])
        validator.check_value_type("average", average, [bool], self.__class__.__name__)
        if self._class_num == 0:
            raise RuntimeError('Input number of samples can not be 0.')
@@ -17,7 +17,7 @@ import sys
import numpy as np
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .evaluation import EvaluationBase
@@ -136,7 +136,7 @@ class Precision(EvaluationBase):
        if self._class_num == 0:
            raise RuntimeError('Input number of samples can not be 0.')
        validator.check_type("average", average, [bool])
        validator.check_value_type("average", average, [bool], self.__class__.__name__)
        result = self._true_positives / (self._positives + self.eps)
        if average:
@@ -17,7 +17,7 @@ import sys
import numpy as np
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .evaluation import EvaluationBase
@@ -136,7 +136,7 @@ class Recall(EvaluationBase):
        if self._class_num == 0:
            raise RuntimeError('Input number of samples can not be 0.')
        validator.check_type("average", average, [bool])
        validator.check_value_type("average", average, [bool], self.__class__.__name__)
        result = self._true_positives / (self._actual_positives + self.eps)
        if average:
@@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
@@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
    return next_v
def _check_param_value(beta1, beta2, eps, weight_decay):
def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_type("beta1", beta1, [float])
    validator.check_type("beta2", beta2, [float])
    validator.check_type("eps", eps, [float])
    validator.check_type("weight_dacay", weight_decay, [float])
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT)
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
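A sketch of the tightened helper in isolation, assuming it is importable from the Adam optimizer module (the values are Adam's usual defaults):

```python
from mindspore.nn.optim.adam import _check_param_value

_check_param_value(0.9, 0.999, 1e-8, 0.0, 'Adam')    # passes: types and ranges all valid
# _check_param_value(1.0, 0.999, 1e-8, 0.0, 'Adam')  # ValueError: beta1 must be in (0.0, 1.0)
```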
| @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | |||
| @@ -168,11 +168,11 @@ class Adam(Optimizer): | |||
| use_nesterov=False, weight_decay=0.0, loss_scale=1.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | |||
| _check_param_value(beta1, beta2, eps, weight_decay) | |||
| validator.check_type("use_locking", use_locking, [bool]) | |||
| validator.check_type("use_nesterov", use_nesterov, [bool]) | |||
| validator.check_type("loss_scale", loss_scale, [float]) | |||
| validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT) | |||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||
| validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) | |||
| validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) | |||
| validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) | |||
| validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| self.beta1 = Tensor(beta1, mstype.float32) | |||
| self.beta2 = Tensor(beta2, mstype.float32) | |||
| @@ -241,7 +241,7 @@ class AdamWeightDecay(Optimizer): | |||
| """ | |||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): | |||
| super(AdamWeightDecay, self).__init__(learning_rate, params) | |||
| _check_param_value(beta1, beta2, eps, weight_decay) | |||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||
| self.lr = Tensor(np.array([learning_rate]).astype(np.float32)) | |||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | |||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | |||
| @@ -304,7 +304,7 @@ class AdamWeightDecayDynamicLR(Optimizer): | |||
| eps=1e-6, | |||
| weight_decay=0.0): | |||
| super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) | |||
| _check_param_value(beta1, beta2, eps, weight_decay) | |||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||
| # turn them to scalar when me support scalar/tensor mix operations | |||
| self.global_step = Parameter(initializer(0, [1]), name="global_step") | |||
| @@ -18,7 +18,7 @@ from mindspore.common.initializer import initializer | |||
| from mindspore.common.parameter import Parameter | |||
| from mindspore.common import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore._checkparam import ParamValidator as validator | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore._checkparam import Rel | |||
| from .optimizer import Optimizer, apply_decay, grad_scale | |||
@@ -30,29 +30,30 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig
    success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
    return success
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0):
    validator.check_type("initial_accum", initial_accum, [float])
    validator.check("initial_accum", initial_accum, "", 0.0, Rel.GE)
def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0,
                 prim_name=None):
    validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
    validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
    validator.check_type("learning_rate", learning_rate, [float])
    validator.check("learning_rate", learning_rate, "", 0.0, Rel.GT)
    validator.check_value_type("learning_rate", learning_rate, [float], prim_name)
    validator.check_number("learning_rate", learning_rate, 0.0, Rel.GT, prim_name)
    validator.check_type("lr_power", lr_power, [float])
    validator.check("lr_power", lr_power, "", 0.0, Rel.LE)
    validator.check_value_type("lr_power", lr_power, [float], prim_name)
    validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name)
    validator.check_type("l1", l1, [float])
    validator.check("l1", l1, "", 0.0, Rel.GE)
    validator.check_value_type("l1", l1, [float], prim_name)
    validator.check_number("l1", l1, 0.0, Rel.GE, prim_name)
    validator.check_type("l2", l2, [float])
    validator.check("l2", l2, "", 0.0, Rel.GE)
    validator.check_value_type("l2", l2, [float], prim_name)
    validator.check_number("l2", l2, 0.0, Rel.GE, prim_name)
    validator.check_type("use_locking", use_locking, [bool])
    validator.check_value_type("use_locking", use_locking, [bool], prim_name)
    validator.check_type("loss_scale", loss_scale, [float])
    validator.check("loss_scale", loss_scale, "", 1.0, Rel.GE)
    validator.check_value_type("loss_scale", loss_scale, [float], prim_name)
    validator.check_number("loss_scale", loss_scale, 1.0, Rel.GE, prim_name)
    validator.check_type("weight_decay", weight_decay, [float])
    validator.check("weight_decay", weight_decay, "", 0.0, Rel.GE)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name)
class FTRL(Optimizer):
@@ -94,7 +95,8 @@ class FTRL(Optimizer):
                 use_locking=False, loss_scale=1.0, weight_decay=0.0):
        super(FTRL, self).__init__(learning_rate, params)
        _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay)
        _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay,
                     self.cls_name)
        self.moments = self.parameters.clone(prefix="moments", init=initial_accum)
        self.linear = self.parameters.clone(prefix="linear", init='zeros')
        self.l1 = l1
@@ -21,7 +21,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer
from .. import layer
@@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
                       end_learning_rate, power, beta1, beta2, eps, weight_decay):
                       end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_type("decay_steps", decay_steps, [int])
    validator.check_type("warmup_steps", warmup_steps, [int])
    validator.check_type("start_learning_rate", start_learning_rate, [float])
    validator.check_type("end_learning_rate", end_learning_rate, [float])
    validator.check_type("power", power, [float])
    validator.check_type("beta1", beta1, [float])
    validator.check_type("beta2", beta2, [float])
    validator.check_type("eps", eps, [float])
    validator.check_type("weight_dacay", weight_decay, [float])
    validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT)
    validator.check_value_type("decay_steps", decay_steps, [int], prim_name)
    validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name)
    validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name)
    validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
    validator.check_value_type("power", power, [float], prim_name)
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name)
    validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
    validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
    validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
class Lamb(Optimizer):
@@ -182,7 +182,7 @@ class Lamb(Optimizer):
        super(Lamb, self).__init__(start_learning_rate, params)
        _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate,
                           power, beta1, beta2, eps, weight_decay)
                           power, beta1, beta2, eps, weight_decay, self.cls_name)
        # turn them to scalar when me support scalar/tensor mix operations
        self.global_step = Parameter(initializer(0, [1]), name="global_step")
@@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.common.tensor import Tensor
from mindspore import log as logger
@@ -63,7 +63,7 @@ class Optimizer(Cell):
            self.gather = None
            self.assignadd = None
            self.global_step = None
            validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT)
            validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name)
        else:
            self.dynamic_lr = True
            self.gather = P.GatherV2()
@@ -14,7 +14,7 @@
# ============================================================================
"""rmsprop"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer
rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt")
@@ -144,8 +144,8 @@ class RMSProp(Optimizer):
        self.decay = decay
        self.epsilon = epsilon
        validator.check_type("use_locking", use_locking, [bool])
        validator.check_type("centered", centered, [bool])
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("centered", centered, [bool], self.cls_name)
        self.centered = centered
        if centered:
            self.opt = P.ApplyCenteredRMSProp(use_locking)
@@ -15,7 +15,7 @@
"""sgd"""
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.parameter import Parameter
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
from .optimizer import Optimizer
sgd_opt = C.MultitypeFuncGraph("sgd_opt")
@@ -100,7 +100,7 @@ class SGD(Optimizer):
            raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening))
        self.dampening = dampening
        validator.check_type("nesterov", nesterov, [bool])
        validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
        self.nesterov = nesterov
        self.opt = P.SGD(dampening, weight_decay, nesterov)
@@ -19,7 +19,7 @@ import os
import json
import inspect
from mindspore._c_expression import Oplib
from mindspore._checkparam import ParamValidator as validator
from mindspore._checkparam import Validator as validator
# path of built-in op info register.
BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl"
@@ -43,7 +43,7 @@ def op_info_register(op_info):
        op_info_real = json.dumps(op_info)
    else:
        op_info_real = op_info
    validator.check_type("op_info", op_info_real, [str])
    validator.check_value_type("op_info", op_info_real, [str], None)
    op_lib = Oplib()
    file_path = os.path.realpath(inspect.getfile(func))
    # keep the path custom ops implementation.
@@ -16,7 +16,7 @@
from easydict import EasyDict as edict
from .. import nn
from .._checkparam import ParamValidator as validator
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
@@ -73,14 +73,14 @@ def _check_kwargs(key_words):
            raise ValueError(f"Unsupported arg '{arg}'")
    if 'cast_model_type' in key_words:
        validator.check('cast_model_type', key_words['cast_model_type'],
                        [mstype.float16, mstype.float32], Rel.IN)
        validator.check_type_name('cast_model_type', key_words['cast_model_type'],
                                  [mstype.float16, mstype.float32], None)
    if 'keep_batchnorm_fp32' in key_words:
        validator.check_isinstance('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
        validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None)
    if 'loss_scale_manager' in key_words:
        loss_scale_manager = key_words['loss_scale_manager']
        if loss_scale_manager:
            validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager)
            validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None)
def _add_loss_network(network, loss_fn, cast_model_type):
@@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
            label = _mp_cast_helper(mstype.float32, label)
            return self._loss_fn(F.cast(out, mstype.float32), label)
    validator.check_isinstance('loss_fn', loss_fn, nn.Cell)
    validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
    if cast_model_type == mstype.float16:
        network = WithLossCell(network, loss_fn)
    else:
@@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
        loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
            scale the loss by LossScaleManager. If set, overwrite the level setting.
    """
    validator.check_isinstance('network', network, nn.Cell)
    validator.check_isinstance('optimizer', optimizer, nn.Optimizer)
    validator.check('level', level, "", ['O0', 'O2'], Rel.IN)
    validator.check_value_type('network', network, nn.Cell, None)
    validator.check_value_type('optimizer', optimizer, nn.Optimizer, None)
    validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None)
    _check_kwargs(kwargs)
    config = dict(_config_level[level], **kwargs)
    config = edict(config)
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Loss scale manager abstract class."""
from .._checkparam import ParamValidator as validator
from .._checkparam import Validator as validator
from .._checkparam import Rel
from .. import nn
@@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager):
        if init_loss_scale < 1.0:
            raise ValueError("Loss scale value should be > 1")
        self.loss_scale = init_loss_scale
        validator.check_integer("scale_window", scale_window, 0, Rel.GT)
        validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__)
        self.scale_window = scale_window
        if scale_factor <= 0:
            raise ValueError("Scale factor should be > 1")
@@ -32,7 +32,7 @@ power = 0.5
class TestInputs:
    def test_milestone1(self):
        milestone1 = 1
        with pytest.raises(ValueError):
        with pytest.raises(TypeError):
            dr.piecewise_constant_lr(milestone1, learning_rates)
    def test_milestone2(self):
@@ -46,12 +46,12 @@ class TestInputs:
    def test_learning_rates1(self):
        lr = True
        with pytest.raises(ValueError):
        with pytest.raises(TypeError):
            dr.piecewise_constant_lr(milestone, lr)
    def test_learning_rates2(self):
        lr = [1, 2, 1]
        with pytest.raises(ValueError):
        with pytest.raises(TypeError):
            dr.piecewise_constant_lr(milestone, lr)
    def test_learning_rate_type(self):
@@ -158,7 +158,7 @@ class TestInputs:
    def test_is_stair(self):
        is_stair = 1
        with pytest.raises(ValueError):
        with pytest.raises(TypeError):
            dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)
    def test_min_lr_type(self):
@@ -183,12 +183,12 @@ class TestInputs:
    def test_power(self):
        power1 = True
        with pytest.raises(ValueError):
        with pytest.raises(TypeError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1)
    def test_update_decay_epoch(self):
        update_decay_epoch = 1
        with pytest.raises(ValueError):
        with pytest.raises(TypeError):
            dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch,
                                   power, update_decay_epoch)
@@ -52,7 +52,7 @@ def test_psnr_max_val_negative():
def test_psnr_max_val_bool():
    max_val = True
    with pytest.raises(ValueError):
    with pytest.raises(TypeError):
        net = PSNRNet(max_val)
def test_psnr_max_val_zero():
@@ -51,7 +51,7 @@ def test_ssim_max_val_negative():
def test_ssim_max_val_bool():
    max_val = True
    with pytest.raises(ValueError):
    with pytest.raises(TypeError):
        net = SSIMNet(max_val)
def test_ssim_max_val_zero():
@@ -92,4 +92,4 @@ def test_ssim_k1_k2_wrong_value():
    with pytest.raises(ValueError):
        net = SSIMNet(k2=0.0)
    with pytest.raises(ValueError):
        net = SSIMNet(k2=-1.0)
        net = SSIMNet(k2=-1.0)
@@ -577,14 +577,14 @@ test_cases_for_verify_exception = [
    ('MaxPool2d_ValueError_2', {
        'block': (
            lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"),
            {'exception': ValueError},
            {'exception': TypeError},
        ),
        'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
    }),
    ('MaxPool2d_ValueError_3', {
        'block': (
            lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"),
            {'exception': ValueError},
            {'exception': TypeError},
        ),
        'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))],
    }),
@@ -38,7 +38,7 @@ def test_avgpool2d_error_input():
    """ test_avgpool2d_error_input """
    kernel_size = 5
    stride = 2.3
    with pytest.raises(ValueError):
    with pytest.raises(TypeError):
        nn.AvgPool2d(kernel_size, stride)