Merge pull request !637 from ghzl/learning-rate-make-group-mode
| @@ -103,9 +103,9 @@ def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, po | |||
| validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) | |||
| @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", | |||
| @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor") | |||
| def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1, | |||
| def _run_opt_with_one_number(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, | |||
| moment2): | |||
| """Apply adam optimizer to the weight parameter using Tensor.""" | |||
| success = True | |||
| @@ -136,9 +136,27 @@ class Adam(Optimizer): | |||
| `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`, | |||
| :math:`\epsilon` represents `eps`. | |||
| Note: | |||
| The Adam optimizer supports separating parameter groups. Different parameter groups can set different | |||
| `learning_rate` and `weight_decay`. | |||
| When parameter groups are separated, the weight decay in each group is applied to the parameters in that group | |||
| if its weight_decay value is greater than 0. When parameter groups are not separated, the `weight_decay` in the | |||
| API is applied to the parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0. | |||
| Args: | |||
| params (list[Parameter]): A list of parameter, which will be updated. The element in `params` | |||
| should be class mindspore.Parameter. | |||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter`, the elements are the | |||
| parameters to be updated, and each element should be of class `Parameter`. When the `params` is a list of | |||
| `dict`, the "params", "lr" and "weight_decay" keys can be parsed. | |||
| - params: Required. The value should be a list of `Parameter`. | |||
| - lr: Optional. If "lr" is in the keys, the corresponding value is used as the learning rate. | |||
| If not, the `learning_rate` in the API will be used. | |||
| - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding value is used as the | |||
| weight decay. If not, the `weight_decay` in the API will be used. | |||
| learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is | |||
| Iterable or a Tensor and the dims of the Tensor is 1, | |||
| use dynamic learning rate, then the i-th step will | |||
| @@ -161,8 +179,6 @@ class Adam(Optimizer): | |||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | |||
| loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1. Default: | |||
| 1.0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | |||
| lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name. | |||
| Inputs: | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||
| @@ -172,15 +188,26 @@ class Adam(Optimizer): | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> #1) All parameters use the same learning rate and weight decay | |||
| >>> optim = nn.Adam(params=net.trainable_params()) | |||
| >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | |||
| >>> | |||
| >>> #2) Use parameter groups and set different values | |||
| >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, | |||
| >>> {'params': no_conv_params}] | |||
| >>> opt = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0) | |||
| >>> # The parameters in conv_params will use a learning rate of 0.01 and a weight decay of 0.01. | |||
| >>> # The parameters in no_conv_params do not set a learning rate or weight decay, so they will use | |||
| >>> # the learning rate of 0.1 and the weight decay of 0.0 from the API. | |||
| >>> | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> model = Model(net, loss_fn=loss, optimizer=optim) | |||
| """ | |||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, | |||
| use_nesterov=False, weight_decay=0.0, loss_scale=1.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | |||
| use_nesterov=False, weight_decay=0.0, loss_scale=1.0): | |||
| super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale) | |||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||
| validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) | |||
| validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) | |||
| @@ -216,10 +243,14 @@ class Adam(Optimizer): | |||
| self.beta1_power = beta1_power | |||
| beta2_power = self.beta2_power * self.beta2 | |||
| self.beta2_power = beta2_power | |||
| success = self.hyper_map(F.partial(adam_opt, self.opt, lr, beta1_power, beta2_power, self.beta1, | |||
| self.beta2, self.eps), | |||
| gradients, params, moment1, moment2) | |||
| if self.is_group: | |||
| success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1, | |||
| self.beta2, self.eps), | |||
| lr, gradients, params, moment1, moment2) | |||
| else: | |||
| success = self.hyper_map(F.partial(adam_opt, self.opt, beta1_power, beta2_power, self.beta1, | |||
| self.beta2, self.eps, lr), | |||
| gradients, params, moment1, moment2) | |||
| return success | |||
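To make the dispatch above concrete: in group mode `lr` is a tuple with one learning rate per parameter, so it is passed as an extra mapped sequence, whereas in the non-group branch a single `lr` is baked into the partial. Below is a minimal plain-Python sketch of that pattern, using a made-up `fake_opt` instead of the real `adam_opt`/`hyper_map` machinery:

    from functools import partial

    def fake_opt(lr, grad, param):          # stand-in for the real adam_opt overloads
        return param - lr * grad

    def run(is_group, lr, grads, params):
        if is_group:
            # lr is a per-parameter tuple, mapped alongside grads and params
            return tuple(map(fake_opt, lr, grads, params))
        # a single lr is captured in the partial, as in the else-branch above
        return tuple(map(partial(fake_opt, lr), grads, params))

    print(run(True, (0.01, 0.1), (1.0, 1.0), (5.0, 5.0)))  # (4.99, 4.9)
    print(run(False, 0.1, (1.0, 1.0), (5.0, 5.0)))         # (4.9, 4.9)

The same group/non-group split appears again in the Momentum, RMSProp and SGD changes further down.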
| @@ -262,6 +293,8 @@ class AdamWeightDecay(Optimizer): | |||
| def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(AdamWeightDecay, self).__init__(learning_rate, params) | |||
| if self.is_group: | |||
| raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") | |||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||
| self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) | |||
| self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) | |||
| @@ -330,6 +363,8 @@ class AdamWeightDecayDynamicLR(Optimizer): | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name, | |||
| warmup_steps=0): | |||
| super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) | |||
| if self.is_group: | |||
| raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") | |||
| _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) | |||
| _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name) | |||
| # turn them to scalar when me support scalar/tensor mix operations | |||
| @@ -96,7 +96,8 @@ class FTRL(Optimizer): | |||
| def __init__(self, params, initial_accum=0.1, learning_rate=0.001, lr_power=-0.5, l1=0.0, l2=0.0, | |||
| use_locking=False, loss_scale=1.0, weight_decay=0.0): | |||
| super(FTRL, self).__init__(learning_rate, params) | |||
| if self.is_group: | |||
| raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") | |||
| _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay, | |||
| self.cls_name) | |||
| self.moments = self.parameters.clone(prefix="moments", init=initial_accum) | |||
| @@ -183,6 +183,8 @@ class Lamb(Optimizer): | |||
| decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name): | |||
| super(Lamb, self).__init__(start_learning_rate, params) | |||
| if self.is_group: | |||
| raise RuntimeError(f"The {self.cls_name} optimizer cannot support group setting.") | |||
| _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, | |||
| power, beta1, beta2, eps, weight_decay, self.cls_name) | |||
| @@ -23,7 +23,7 @@ momentum_opt = C.MultitypeFuncGraph("momentum_opt") | |||
| @momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") | |||
| def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): | |||
| def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment): | |||
| """Apply momentum optimizer to the weight parameter using Tensor.""" | |||
| success = True | |||
| success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) | |||
| @@ -36,9 +36,27 @@ class Momentum(Optimizer): | |||
| Refer to the paper on the importance of initialization and momentum in deep learning for more details. | |||
| Note: | |||
| The Momentum optimizer supports separating parameter groups. Different parameter groups can set different | |||
| `learning_rate` and `weight_decay`. | |||
| When parameter groups are separated, the weight decay in each group is applied to the parameters in that group | |||
| if its weight_decay value is greater than 0. When parameter groups are not separated, the `weight_decay` in the | |||
| API is applied to the parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0. | |||
| Args: | |||
| params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters` | |||
| should be class mindspore.Parameter. | |||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter`, the elements are the | |||
| parameters to be updated, and each element should be of class `Parameter`. When the `params` is a list of | |||
| `dict`, the "params", "lr" and "weight_decay" keys can be parsed. | |||
| - params: Required. The value should be a list of `Parameter`. | |||
| - lr: Optional. If "lr" is in the keys, the corresponding value is used as the learning rate. | |||
| If not, the `learning_rate` in the API will be used. | |||
| - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding value is used as the | |||
| weight decay. If not, the `weight_decay` in the API will be used. | |||
| learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is | |||
| Iterable or a Tensor and the dims of the Tensor is 1, | |||
| use dynamic learning rate, then the i-th step will | |||
| @@ -49,8 +67,6 @@ class Momentum(Optimizer): | |||
| momentum (float): Hyperparameter of type float, means momentum for the moving average. | |||
| weight_decay (float): Weight decay (L2 penalty). Default: 0.0. | |||
| loss_scale (float): A floating point value for the loss scale. Default: 1.0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | |||
| lambda x: 'beta' not in x.name and 'gamma' not in x.name. | |||
| Inputs: | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||
| @@ -63,13 +79,24 @@ class Momentum(Optimizer): | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> #1) All parameters use the same learning rate and weight decay | |||
| >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||
| >>> | |||
| >>> #2) Use parameter groups and set different values | |||
| >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, | |||
| >>> {'params': no_conv_params}] | |||
| >>> opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0) | |||
| >>> # The parameters in conv_params will use a learning rate of 0.01 and a weight decay of 0.01. | |||
| >>> # The parameters in no_conv_params do not set a learning rate or weight decay, so they will use | |||
| >>> # the learning rate of 0.1 and the weight decay of 0.0 from the API. | |||
| >>> | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | |||
| """ | |||
| def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | |||
| def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0): | |||
| super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale) | |||
| if isinstance(momentum, float) and momentum < 0.0: | |||
| raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) | |||
| self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") | |||
| @@ -84,5 +111,8 @@ class Momentum(Optimizer): | |||
| gradients = self.decay_weight(gradients) | |||
| gradients = self.scale_grad(gradients) | |||
| lr = self.get_lr() | |||
| success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments) | |||
| if self.is_group: | |||
| success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum), lr, gradients, params, moments) | |||
| else: | |||
| success = self.hyper_map(F.partial(momentum_opt, self.opt, self.momentum, lr), gradients, params, moments) | |||
| return success | |||
| @@ -28,7 +28,6 @@ from mindspore._checkparam import Rel | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore import log as logger | |||
| __all__ = ['Optimizer'] | |||
| @@ -42,68 +41,96 @@ class Optimizer(Cell): | |||
| This class defines the API to add Ops to train a model. Never use | |||
| this class directly, but instead instantiate one of its subclasses. | |||
| Some optimizers support separating parameter groups. Different parameter groups can set different | |||
| `learning_rate` and `weight_decay`. | |||
| When parameter groups are separated, the weight decay in each group is applied to the parameters in that group | |||
| if its weight_decay value is greater than 0. When parameter groups are not separated, the `weight_decay` in the | |||
| API is applied to the parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0. | |||
| Args: | |||
| learning_rate (float): A floating point value for the learning rate. Should be greater than 0. | |||
| parameters (list): A list of parameter, which will be updated. The element in `parameters` | |||
| should be class mindspore.Parameter. | |||
| parameters (Union[list[Parameter], list[dict]]): When the `parameters` is a list of `Parameter`, the elements | |||
| are the parameters to be updated, and each element should be of class `Parameter`. When the `parameters` is a | |||
| list of `dict`, the "params", "lr" and "weight_decay" keys can be parsed. | |||
| - params: Required. The value should be a list of `Parameter`. | |||
| - lr: Optional. If "lr" is in the keys, the corresponding value is used as the learning rate. | |||
| If not, the `learning_rate` in the API will be used. | |||
| - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding value is used as the | |||
| weight decay. If not, the `weight_decay` in the API will be used. | |||
| weight_decay (float): A floating point value for the weight decay. It should be equal to or greater than 0. | |||
| If the type of `weight_decay` input is int, it will be convertd to float. Default: 0.0. | |||
| If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0. | |||
| loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the | |||
| type of `loss_scale` input is int, it will be convertd to float. Default: 1.0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda | |||
| x: 'beta' not in x.name and 'gamma' not in x.name. | |||
| type of `loss_scale` input is int, it will be converted to float. Default: 1.0. | |||
| Raises: | |||
| ValueError: If the learning_rate is a Tensor, but the dims of tensor is greater than 1. | |||
| TypeError: If the learning_rate is not any of the three types: float, Tensor, Iterable. | |||
| """ | |||
| def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0): | |||
| super(Optimizer, self).__init__(auto_prefix=False) | |||
| if parameters and not isinstance(parameters, list): | |||
| parameters = list(parameters) | |||
| if not parameters: | |||
| raise ValueError("Optimizer got an empty parameter list.") | |||
| if not isinstance(parameters[0], (dict, Parameter)): | |||
| raise ValueError("Only a list of Parameter or dict can be supported.") | |||
| if isinstance(loss_scale, int): | |||
| loss_scale = float(loss_scale) | |||
| validator.check_value_type("loss_scale", loss_scale, [float], None) | |||
| validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None) | |||
| if isinstance(weight_decay, int): | |||
| weight_decay = float(weight_decay) | |||
| validator.check_value_type("weight_decay", weight_decay, [float], None) | |||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None) | |||
| self.is_group = False | |||
| self.loss_scale = loss_scale | |||
| if isinstance(learning_rate, float): | |||
| self.dynamic_lr = False | |||
| self.gather = None | |||
| self.assignadd = None | |||
| self.global_step = None | |||
| validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| learning_rate = Tensor(learning_rate, mstype.float32) | |||
| self.scalar_lr = learning_rate | |||
| else: | |||
| self.dynamic_lr = True | |||
| self.gather = P.GatherV2() | |||
| self.assignadd = P.AssignAdd() | |||
| self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') | |||
| if isinstance(learning_rate, Iterable): | |||
| learning_rate = Tensor(np.array(list(learning_rate)).astype(np.float32)) | |||
| elif isinstance(learning_rate, Tensor): | |||
| if learning_rate.dim() > 1: | |||
| raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`," | |||
| f"but got {learning_rate.dim()}.") | |||
| if learning_rate.dim() == 1 and learning_rate.size() < 2: | |||
| logger.warning("If want to use the dynamic learning rate, please make sure that the number " | |||
| "of elements in the list, tuple or tensor passed is greater than 1.") | |||
| else: | |||
| raise TypeError("Learning rate should be float, Tensor or Iterable.") | |||
| if isinstance(weight_decay, int): | |||
| weight_decay = float(weight_decay) | |||
| validator.check_value_type("weight_decay", weight_decay, [float], None) | |||
| validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, None) | |||
| if isinstance(loss_scale, int): | |||
| loss_scale = float(loss_scale) | |||
| validator.check_value_type("loss_scale", loss_scale, [float], None) | |||
| validator.check_number_range("loss_scale", loss_scale, 0.0, float("inf"), Rel.INC_NEITHER, None) | |||
| self.loss_scale = loss_scale | |||
| self.learning_rate = Parameter(learning_rate, name="learning_rate") | |||
| self.parameters = ParameterTuple(parameters) | |||
| self.scalar_lr = None | |||
| learning_rate = self._get_single_lr(learning_rate) | |||
| if isinstance(parameters[0], dict): | |||
| self.is_group = True | |||
| self.params = [] | |||
| self.group_lr = [] | |||
| self.group_weight_decay = [] | |||
| self._init_group_params(parameters, learning_rate, weight_decay) | |||
| if self.is_group: | |||
| self.learning_rate = ParameterTuple(self.group_lr) | |||
| self.parameters = ParameterTuple(self.params) | |||
| self.weight_decay = tuple(self.group_weight_decay) | |||
| decay_filter = lambda x: x > 0 | |||
| self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) | |||
| else: | |||
| self.learning_rate = Parameter(learning_rate, name="learning_rate") | |||
| self.parameters = ParameterTuple(parameters) | |||
| self.weight_decay = weight_decay * loss_scale | |||
| decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name | |||
| self.decay_flags = tuple(decay_filter(x) for x in self.parameters) | |||
| self.reciprocal_scale = 1.0 / loss_scale | |||
| self.weight_decay = weight_decay * loss_scale | |||
| self.decay_flags = tuple(decay_filter(x) for x in self.parameters) | |||
| if not self.parameters: | |||
| raise ValueError("optimizer got an empty parameter list.") | |||
| self.exec_weight_decay = any(self.decay_flags) | |||
| self.param_length = len(self.parameters) | |||
| def decay_weight(self, gradients): | |||
| """ | |||
| @@ -118,9 +145,15 @@ class Optimizer(Cell): | |||
| Returns: | |||
| tuple[Tensor], The gradients after weight decay. | |||
| """ | |||
| if self.weight_decay > 0: | |||
| params = self.parameters | |||
| gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients) | |||
| params = self.parameters | |||
| if self.is_group: | |||
| if self.exec_weight_decay: | |||
| gradients = self.hyper_map(F.partial(apply_decay), self.weight_decay, self.decay_flags, | |||
| params, gradients) | |||
| else: | |||
| if self.weight_decay > 0: | |||
| gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, | |||
| params, gradients) | |||
| return gradients | |||
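As a rough illustration of the branch above (and not the actual `apply_decay` implementation, which is not shown in this diff), L2 weight decay adds `weight_decay * parameter` to each gradient whose decay flag is set; in group mode the coefficient can differ per parameter:

    def decay_weight_sketch(weight_decay, decay_flags, params, gradients):
        """Plain-Python sketch of per-parameter L2 weight decay (assumed semantics).

        `weight_decay` is a single float in non-group mode or one coefficient per
        parameter in group mode; `decay_flags` holds one bool per parameter.
        """
        if not isinstance(weight_decay, (list, tuple)):
            weight_decay = [weight_decay] * len(params)
        return [g + wd * p if flag and wd > 0 else g
                for wd, flag, p, g in zip(weight_decay, decay_flags, params, gradients)]

    params, grads = [1.0, 1.0], [0.5, 0.5]
    print(decay_weight_sketch([0.01, 0.0], (True, False), params, grads))  # [0.51, 0.5]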
| @@ -144,6 +177,83 @@ class Optimizer(Cell): | |||
| return gradients | |||
| def _get_single_lr(self, learning_rate): | |||
| """Get learning rate in Tensor type.""" | |||
| if isinstance(learning_rate, float): | |||
| validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| lr = Tensor(learning_rate, mstype.float32) | |||
| elif isinstance(learning_rate, Iterable): | |||
| lr = Tensor(np.array(list(learning_rate)).astype(np.float32)) | |||
| elif isinstance(learning_rate, Tensor): | |||
| if learning_rate.dim() > 1: | |||
| raise ValueError("Learning rate should be a 0 or 1 dim `Tensor`," | |||
| f"but got {learning_rate.dim()}.") | |||
| if learning_rate.dim() == 1 and learning_rate.size() < 2: | |||
| logger.warning("If want to use the dynamic learning rate, please make sure that the number " | |||
| "of elements in the list, tuple or tensor passed is greater than 1.") | |||
| lr = learning_rate | |||
| else: | |||
| raise TypeError("Learning rate should be float, Tensor or Iterable.") | |||
| return lr | |||
| def _init_group_params(self, parameters, learning_rate, weight_decay): | |||
| """Init learning rate or weight decay in group params.""" | |||
| origin_dynamic_lr = self.dynamic_lr | |||
| if self.dynamic_lr: | |||
| dynamic_lr_length = learning_rate.size() | |||
| else: | |||
| dynamic_lr_length = 0 | |||
| for group_param in parameters: | |||
| lr_length = dynamic_lr_length | |||
| if 'lr' in group_param.keys(): | |||
| self._get_single_lr(group_param['lr']) | |||
| if isinstance(group_param['lr'], Iterable): | |||
| lr_length = len(group_param['lr']) | |||
| self.dynamic_lr = True | |||
| elif isinstance(group_param['lr'], Tensor): | |||
| lr_length = group_param['lr'].size() | |||
| self.dynamic_lr = True | |||
| if dynamic_lr_length not in (lr_length, 0): | |||
| raise ValueError("The dynamic learning rate in group should be the same size.") | |||
| dynamic_lr_length = lr_length | |||
| if self.dynamic_lr and not origin_dynamic_lr: | |||
| self.gather = P.GatherV2() | |||
| self.assignadd = P.AssignAdd() | |||
| self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step') | |||
| params_store = [] | |||
| for group_param in parameters: | |||
| self.params += group_param['params'] | |||
| if 'lr' in group_param.keys(): | |||
| params_dynamic_lr = isinstance(group_param['lr'], (Iterable, Tensor)) | |||
| if self.dynamic_lr and not params_dynamic_lr: | |||
| lr = Tensor(np.array([group_param['lr']] * dynamic_lr_length).astype(np.float32)) | |||
| else: | |||
| lr = self._get_single_lr(group_param['lr']) | |||
| else: | |||
| if self.dynamic_lr and not origin_dynamic_lr: | |||
| lr = Tensor(np.array([self.scalar_lr] * dynamic_lr_length).astype(np.float32)) | |||
| else: | |||
| lr = learning_rate | |||
| if 'weight_decay' in group_param.keys(): | |||
| validator.check_float_legal_value('weight_decay', group_param['weight_decay'], None) | |||
| validator.check_number_range('weight_decay', group_param['weight_decay'], 0.0, float("inf"), | |||
| Rel.INC_LEFT, self.cls_name) | |||
| weight_decay_ = group_param['weight_decay'] * self.loss_scale | |||
| else: | |||
| weight_decay_ = weight_decay * self.loss_scale | |||
| for param in group_param['params']: | |||
| if param in params_store: | |||
| raise RuntimeError(f"The {param.name} parameter has appeared in parameter groups.") | |||
| params_store.append(param) | |||
| self.group_lr.append(Parameter(lr, name="lr_" + param.name)) | |||
| self.group_weight_decay.append(weight_decay_) | |||
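A minimal sketch, in plain Python with hypothetical parameter names, of the per-parameter expansion that `_init_group_params` performs: every parameter ends up with its own learning rate and weight decay entry, falling back to the API-level values when the group dict omits the key:

    # Hypothetical parameter names; API-level defaults as in the examples above.
    group_params = [
        {'params': ['conv1.weight', 'conv2.weight'], 'lr': 0.01, 'weight_decay': 0.01},
        {'params': ['fc1.weight', 'fc1.bias']},  # no keys set: fall back to the API values
    ]
    default_lr, default_weight_decay = 0.1, 0.0

    flat_params, group_lr, group_weight_decay = [], [], []
    for group in group_params:
        lr = group.get('lr', default_lr)
        wd = group.get('weight_decay', default_weight_decay)
        for name in group['params']:
            flat_params.append(name)
            group_lr.append(lr)            # one learning rate entry per parameter
            group_weight_decay.append(wd)  # one weight decay entry per parameter

    decay_flags = tuple(wd > 0 for wd in group_weight_decay)
    # group_lr           -> [0.01, 0.01, 0.1, 0.1]
    # group_weight_decay -> [0.01, 0.01, 0.0, 0.0]
    # decay_flags        -> (True, True, False, False)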
| def get_lr(self): | |||
| """ | |||
| Get the learning rate of current step. | |||
| @@ -151,11 +261,20 @@ class Optimizer(Cell): | |||
| Returns: | |||
| float, the learning rate of current step. | |||
| """ | |||
| lr = self.learning_rate | |||
| if self.dynamic_lr: | |||
| lr = self.gather(self.learning_rate, self.global_step, 0) | |||
| F.control_depend(lr, self.assignadd(self.global_step, 1)) | |||
| if self.is_group: | |||
| lr = self.learning_rate | |||
| if self.dynamic_lr: | |||
| lr = () | |||
| for i in range(self.param_length): | |||
| current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0) | |||
| lr += (current_dynamic_lr,) | |||
| F.control_depend(lr, self.assignadd(self.global_step, 1)) | |||
| else: | |||
| lr = self.learning_rate | |||
| if self.dynamic_lr: | |||
| lr = self.gather(self.learning_rate, self.global_step, 0) | |||
| F.control_depend(lr, self.assignadd(self.global_step, 1)) | |||
| return lr | |||
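A simplified analogue of the group branch of `get_lr`, using plain Python lists in place of Tensors and `GatherV2`: in group mode one learning rate is returned per parameter, and with a dynamic learning rate the value for the current `global_step` is selected from each per-parameter schedule:

    def get_lr_sketch(group_lr, global_step, dynamic_lr):
        """Sketch: return one learning rate per parameter in group mode."""
        if not dynamic_lr:
            return tuple(group_lr)
        # With a dynamic lr, every parameter has a per-step schedule of equal length
        # (enforced in _init_group_params); pick the value for the current step.
        return tuple(schedule[global_step] for schedule in group_lr)

    # First parameter follows a 3-step schedule, second uses a constant broadcast to
    # the same length, mirroring the expansion sketched above.
    print(get_lr_sketch([[0.1, 0.2, 0.3], [0.8, 0.8, 0.8]], global_step=1, dynamic_lr=True))  # (0.2, 0.8)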
| def construct(self, *hyper_params): | |||
| @@ -22,17 +22,17 @@ rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | |||
| centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") | |||
| @rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") | |||
| def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): | |||
| @rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") | |||
| def _rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, ms, mom, grad): | |||
| """Apply rmsprop optimizer to the weight parameter using dynamic learning rate.""" | |||
| success = True | |||
| success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) | |||
| return success | |||
| @centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", | |||
| @centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", | |||
| "Tensor", "Tensor") | |||
| def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): | |||
| def _centered_rmsprop_opt(opt, decay, epsilon, momentum, learning_rate, weight, mg, ms, mom, grad): | |||
| """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate.""" | |||
| success = True | |||
| success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) | |||
| @@ -44,6 +44,13 @@ class RMSProp(Optimizer): | |||
| Implements Root Mean Squared Propagation (RMSProp) algorithm. | |||
| Note: | |||
| The RMSProp optimizer supports separating parameter groups. Different parameter groups can set different | |||
| `learning_rate` and `weight_decay`. | |||
| When parameter groups are separated, the weight decay in each group is applied to the parameters in that group | |||
| if its weight_decay value is greater than 0. When parameter groups are not separated, the `weight_decay` in the | |||
| API is applied to the parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0. | |||
| Update `params` according to the RMSProp algorithm. | |||
| The equation is as follows: | |||
| @@ -84,8 +91,18 @@ class RMSProp(Optimizer): | |||
| represents `gradients`. | |||
| Args: | |||
| params (list[Parameter]): A list of parameter, which will be updated. The element in `parameters` | |||
| should be class mindspore.Parameter. | |||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter`, the elements are the | |||
| parameters to be updated, and each element should be of class `Parameter`. When the `params` is a list of | |||
| `dict`, the "params", "lr" and "weight_decay" keys can be parsed. | |||
| - params: Required. The value should be a list of `Parameter`. | |||
| - lr: Optional. If "lr" is in the keys, the corresponding value is used as the learning rate. | |||
| If not, the `learning_rate` in the API will be used. | |||
| - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding value is used as the | |||
| weight decay. If not, the `weight_decay` in the API will be used. | |||
| learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is | |||
| Iterable or a Tensor and the dims of the Tensor is 1, | |||
| use dynamic learning rate, then the i-th step will | |||
| @@ -95,15 +112,13 @@ class RMSProp(Optimizer): | |||
| Other cases are not supported. Default: 0.1. | |||
| decay (float): Decay rate. Should be equal to or greater than 0. Default: 0.9. | |||
| momentum (float): Hyperparameter of type float, means momentum for the moving average. Should be equal to or | |||
| greater than 0.Default: 0.0. | |||
| greater than 0. Default: 0.0. | |||
| epsilon (float): Term added to the denominator to improve numerical stability. Should be greater than | |||
| 0. Default: 1e-10. | |||
| use_locking (bool): Enable a lock to protect the update of variable and accumulation tensors. Default: False. | |||
| centered (bool): If True, gradients are normalized by the estimated variance of the gradient. Default: False. | |||
| loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0. | |||
| weight_decay (float): Weight decay (L2 penalty). Should be equal to or greater than 0. Default: 0.0. | |||
| decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: | |||
| lambda x: 'beta' not in x.name and 'gamma' not in x.name. | |||
| Inputs: | |||
| - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`. | |||
| @@ -113,14 +128,25 @@ class RMSProp(Optimizer): | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> #1) All parameters use the same learning rate and weight decay | |||
| >>> optim = nn.RMSProp(params=net.trainable_params(), learning_rate=0.1) | |||
| >>> | |||
| >>> #2) Use parameter groups and set different values | |||
| >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, | |||
| >>> {'params': no_conv_params}] | |||
| >>> opt = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0) | |||
| >>> # The parameters in conv_params will use a learning rate of 0.01 and a weight decay of 0.01. | |||
| >>> # The parameters in no_conv_params do not set a learning rate or weight decay, so they will use | |||
| >>> # the learning rate of 0.1 and the weight decay of 0.0 from the API. | |||
| >>> | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> opt = nn.RMSProp(params=net.trainable_params(), learning_rate=lr) | |||
| >>> model = Model(net, loss, opt) | |||
| >>> model = Model(net, loss_fn=loss, optimizer=optim) | |||
| """ | |||
| def __init__(self, params, learning_rate=0.1, decay=0.9, momentum=0.0, epsilon=1e-10, | |||
| use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0, | |||
| decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): | |||
| super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) | |||
| use_locking=False, centered=False, loss_scale=1.0, weight_decay=0.0): | |||
| super(RMSProp, self).__init__(learning_rate, params, weight_decay, loss_scale) | |||
| validator.check_value_type("decay", decay, [float], self.cls_name) | |||
| validator.check_number_range("decay", decay, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) | |||
| validator.check_value_type("momentum", momentum, [float], self.cls_name) | |||
| @@ -150,9 +176,18 @@ class RMSProp(Optimizer): | |||
| gradients = self.scale_grad(gradients) | |||
| lr = self.get_lr() | |||
| if self.centered: | |||
| success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, lr, self.decay, self.epsilon, | |||
| self.momentum), params, self.mg, self.ms, self.moment, gradients) | |||
| if self.is_group: | |||
| success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon, | |||
| self.momentum), lr, params, self.mg, self.ms, self.moment, gradients) | |||
| else: | |||
| success = self.hyper_map(F.partial(centered_rmsprop_opt, self.opt, self.decay, self.epsilon, | |||
| self.momentum, lr), params, self.mg, self.ms, self.moment, gradients) | |||
| else: | |||
| success = self.hyper_map(F.partial(rmsprop_opt, self.opt, lr, self.decay, self.epsilon, | |||
| self.momentum), params, self.ms, self.moment, gradients) | |||
| if self.is_group: | |||
| success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon, | |||
| self.momentum), lr, params, self.ms, self.moment, gradients) | |||
| else: | |||
| success = self.hyper_map(F.partial(rmsprop_opt, self.opt, self.decay, self.epsilon, | |||
| self.momentum, lr), params, self.ms, self.moment, gradients) | |||
| return success | |||
| @@ -24,7 +24,7 @@ sgd_opt = C.MultitypeFuncGraph("sgd_opt") | |||
| @sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") | |||
| def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat): | |||
| def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat): | |||
| """Apply sgd optimizer to the weight parameter using Tensor.""" | |||
| success = True | |||
| success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat)) | |||
| @@ -39,9 +39,27 @@ class SGD(Optimizer): | |||
| Nesterov momentum is based on the formula from paper `On the importance of initialization and | |||
| momentum in deep learning <http://proceedings.mlr.press/v28/sutskever13.html>`_. | |||
| Note: | |||
| The SGD optimizer supports separating parameter groups. Different parameter groups can set different | |||
| `learning_rate` and `weight_decay`. | |||
| When parameter groups are separated, the weight decay in each group is applied to the parameters in that group | |||
| if its weight_decay value is greater than 0. When parameter groups are not separated, the `weight_decay` in the | |||
| API is applied to the parameters whose names contain neither 'beta' nor 'gamma', provided `weight_decay` > 0. | |||
| Args: | |||
| params (list[Parameter]): A list of parameter, which will be updated. The element in `params` | |||
| should be class mindspore.Parameter. | |||
| params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter`, the elements are the | |||
| parameters to be updated, and each element should be of class `Parameter`. When the `params` is a list of | |||
| `dict`, the "params", "lr" and "weight_decay" keys can be parsed. | |||
| - params: Required. The value should be a list of `Parameter`. | |||
| - lr: Optional. If "lr" is in the keys, the corresponding value is used as the learning rate. | |||
| If not, the `learning_rate` in the API will be used. | |||
| - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding value is used as the | |||
| weight decay. If not, the `weight_decay` in the API will be used. | |||
| learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is | |||
| Iterable or a Tensor and the dims of the Tensor is 1, | |||
| use dynamic learning rate, then the i-th step will | |||
| @@ -67,9 +85,21 @@ class SGD(Optimizer): | |||
| Examples: | |||
| >>> net = Net() | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> #1) All parameters use the same learning rate and weight decay | |||
| >>> optim = nn.SGD(params=net.trainable_params()) | |||
| >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | |||
| >>> | |||
| >>> #2) Use parameter groups and set different values | |||
| >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'lr': 0.01}, | |||
| >>> {'params': no_conv_params}] | |||
| >>> opt = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0) | |||
| >>> # The parameters in conv_params will use a learning rate of 0.01 and a weight decay of 0.01. | |||
| >>> # The parameters in no_conv_params do not set a learning rate or weight decay, so they will use | |||
| >>> # the learning rate of 0.1 and the weight decay of 0.0 from the API. | |||
| >>> | |||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| >>> model = Model(net, loss_fn=loss, optimizer=optim) | |||
| """ | |||
| def __init__(self, params, learning_rate=0.1, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=False, | |||
| loss_scale=1.0): | |||
| @@ -109,5 +139,8 @@ class SGD(Optimizer): | |||
| gradients = self.decay_weight(gradients) | |||
| gradients = self.scale_grad(gradients) | |||
| lr = self.get_lr() | |||
| success = self.hyper_map(F.partial(sgd_opt, self.opt, lr, self.momentum), gradients, params, accum, stat) | |||
| if self.is_group: | |||
| success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat) | |||
| else: | |||
| success = self.hyper_map(F.partial(sgd_opt, self.opt, self.momentum, lr), gradients, params, accum, stat) | |||
| return success | |||
| @@ -167,7 +167,7 @@ class TrainOneStepCell(Cell): | |||
| super(TrainOneStepCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.network.add_flags(defer_inline=True) | |||
| self.weights = ParameterTuple(network.trainable_params()) | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||
| self.sens = sens | |||
| @@ -50,7 +50,7 @@ class NetWithoutWeight(nn.Cell): | |||
| def test_adamwithoutparam(): | |||
| net = NetWithoutWeight() | |||
| net.set_train() | |||
| with pytest.raises(ValueError, match=r"optimizer got an empty parameter list"): | |||
| with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"): | |||
| AdamWeightDecay(net.trainable_params(), learning_rate=0.1) | |||
| @@ -104,5 +104,5 @@ def test_AdamWeightDecayDynamicLR(): | |||
| def test_adam_mindspore_flatten(): | |||
| net = nn.Flatten() | |||
| with pytest.raises(ValueError, match=r"optimizer got an empty parameter list"): | |||
| with pytest.raises(ValueError, match=r"Optimizer got an empty parameter list"): | |||
| AdamWeightDecay(net.get_parameters()) | |||
| @@ -69,19 +69,19 @@ class TestSGD(): | |||
| class TestNullParam(): | |||
| """ TestNullParam definition """ | |||
| def test_optim_init(self): | |||
| with pytest.raises(TypeError): | |||
| with pytest.raises(ValueError): | |||
| Optimizer(0.1, None) | |||
| def test_AdamWightDecay_init(self): | |||
| with pytest.raises(TypeError): | |||
| with pytest.raises(ValueError): | |||
| AdamWeightDecay(None) | |||
| def test_AdamWeightDecayDynamicLR_init(self): | |||
| with pytest.raises(TypeError): | |||
| with pytest.raises(ValueError): | |||
| AdamWeightDecayDynamicLR(None, 10) | |||
| def test_Sgd_init(self): | |||
| with pytest.raises(TypeError): | |||
| with pytest.raises(ValueError): | |||
| SGD(None) | |||
| class TestUnsupportParam(): | |||
| @@ -0,0 +1,210 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.nn as nn | |||
| from mindspore.nn.optim import Momentum, SGD, RMSProp, Adam | |||
| from mindspore import context | |||
| from mindspore.common.api import _executor | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class LeNet5(nn.Cell): | |||
| """ LeNet5 definition """ | |||
| def __init__(self): | |||
| super(LeNet5, self).__init__() | |||
| self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid') | |||
| self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') | |||
| self.fc1 = nn.Dense(16 * 5 * 5, 120) | |||
| self.fc2 = nn.Dense(120, 84) | |||
| self.fc3 = nn.Dense(84, 10) | |||
| self.relu = nn.ReLU() | |||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.flatten = P.Flatten() | |||
| def construct(self, x): | |||
| x = self.max_pool2d(self.relu(self.conv1(x))) | |||
| x = self.max_pool2d(self.relu(self.conv2(x))) | |||
| x = self.flatten(x) | |||
| x = self.relu(self.fc1(x)) | |||
| x = self.relu(self.fc2(x)) | |||
| x = self.fc3(x) | |||
| return x | |||
| def test_group_lr(): | |||
| inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) | |||
| label = Tensor(np.ones([1, 10]).astype(np.float32)) | |||
| net = LeNet5() | |||
| conv_lr = 0.8 | |||
| default_lr = 0.1 | |||
| conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| group_params = [{'params': conv_params, 'lr': conv_lr}, | |||
| {'params': no_conv_params}] | |||
| net.set_train() | |||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9) | |||
| assert opt.is_group is True | |||
| assert opt.dynamic_lr is False | |||
| for lr, param in zip(opt.learning_rate, opt.parameters): | |||
| if param in conv_params: | |||
| assert lr.data == Tensor(conv_lr, mstype.float32) | |||
| else: | |||
| assert lr.data == Tensor(default_lr, mstype.float32) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, opt) | |||
| _executor.compile(train_network, inputs, label) | |||
| def test_group_dynamic_1(): | |||
| inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) | |||
| label = Tensor(np.ones([1, 10]).astype(np.float32)) | |||
| net = LeNet5() | |||
| conv_lr = 0.8 | |||
| default_lr = (0.1, 0.2, 0.3) | |||
| conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| group_params = [{'params': conv_params, 'lr': conv_lr}, | |||
| {'params': no_conv_params}] | |||
| net.set_train() | |||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| opt = Momentum(group_params, learning_rate=default_lr, momentum=0.9) | |||
| assert opt.is_group is True | |||
| assert opt.dynamic_lr is True | |||
| for lr, param in zip(opt.learning_rate, opt.parameters): | |||
| if param in conv_params: | |||
| assert lr.data == Tensor(np.array([conv_lr] * 3).astype(np.float32)) | |||
| else: | |||
| assert lr.data == Tensor(np.array(list(default_lr)).astype(np.float32)) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, opt) | |||
| _executor.compile(train_network, inputs, label) | |||
| def test_group_dynamic_2(): | |||
| inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) | |||
| label = Tensor(np.ones([1, 10]).astype(np.float32)) | |||
| net = LeNet5() | |||
| conv_lr = (0.1, 0.2, 0.3) | |||
| default_lr = 0.8 | |||
| conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| group_params = [{'params': conv_params, 'lr': conv_lr}, | |||
| {'params': no_conv_params}] | |||
| net.set_train() | |||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| opt = RMSProp(group_params, learning_rate=default_lr) | |||
| assert opt.is_group is True | |||
| assert opt.dynamic_lr is True | |||
| for lr, param in zip(opt.learning_rate, opt.parameters): | |||
| if param in conv_params: | |||
| assert lr.data == Tensor(np.array(list(conv_lr)).astype(np.float32)) | |||
| else: | |||
| assert lr.data == Tensor(np.array([default_lr] * 3).astype(np.float32)) | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, opt) | |||
| _executor.compile(train_network, inputs, label) | |||
| def test_group_dynamic_no_same_size(): | |||
| net = LeNet5() | |||
| conv_lr = (0.1, 0.2, 0.3) | |||
| default_lr = (0.1, 0.2) | |||
| conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| group_params = [{'params': conv_params, 'lr': conv_lr}, | |||
| {'params': no_conv_params}] | |||
| with pytest.raises(ValueError): | |||
| Momentum(group_params, learning_rate=default_lr, momentum=0.9) | |||
| def test_group_not_float_lr(): | |||
| net = LeNet5() | |||
| conv_lr = 1 | |||
| default_lr = 0.3 | |||
| conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| group_params = [{'params': conv_params, 'lr': conv_lr}, | |||
| {'params': no_conv_params}] | |||
| with pytest.raises(TypeError): | |||
| Momentum(group_params, learning_rate=default_lr, momentum=0.9) | |||
| def test_group_not_float_weight_decay(): | |||
| net = LeNet5() | |||
| conv_weight_decay = 1 | |||
| conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, | |||
| {'params': no_conv_params}] | |||
| with pytest.raises(TypeError): | |||
| Momentum(group_params, learning_rate=0.1, momentum=0.9) | |||
| def test_weight_decay(): | |||
| inputs = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) | |||
| label = Tensor(np.ones([1, 10]).astype(np.float32)) | |||
| net = LeNet5() | |||
| conv_weight_decay = 0.8 | |||
| default_weight_decay = 0.0 | |||
| conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| group_params = [{'params': conv_params, 'weight_decay': conv_weight_decay}, | |||
| {'params': no_conv_params}] | |||
| net.set_train() | |||
| loss = nn.SoftmaxCrossEntropyWithLogits() | |||
| opt = SGD(group_params, learning_rate=0.1, weight_decay=default_weight_decay) | |||
| assert opt.is_group is True | |||
| for weight_decay, decay_flags, param in zip(opt.weight_decay, opt.decay_flags, opt.parameters): | |||
| if param in conv_params: | |||
| assert weight_decay == conv_weight_decay | |||
| assert decay_flags is True | |||
| else: | |||
| assert weight_decay == default_weight_decay | |||
| assert decay_flags is False | |||
| net_with_loss = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(net_with_loss, opt) | |||
| _executor.compile(train_network, inputs, label) | |||
| def test_group_repeat_param(): | |||
| net = LeNet5() | |||
| conv_lr = 0.1 | |||
| default_lr = 0.3 | |||
| conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params())) | |||
| no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params())) | |||
| group_params = [{'params': conv_params, 'lr': conv_lr}, | |||
| {'params': conv_params, 'lr': default_lr}, | |||
| {'params': no_conv_params}] | |||
| with pytest.raises(RuntimeError): | |||
| Adam(group_params, learning_rate=default_lr) | |||