From: @lijiaqi0612
Reviewed-by: @kingxian, @zh_qh
Signed-off-by: @kingxian
Tags: v1.2.0-rc1
@@ -67,6 +67,10 @@ class Adagrad(Optimizer):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
@@ -98,12 +102,14 @@ class Adagrad(Optimizer):
 >>> #2) Use parameter groups and set different values
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params, 'lr': 0.01},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.Adagrad(group_params, learning_rate=0.1, weight_decay=0.0)
- >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
- >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
+ >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
+ >>> # grad centralization of True.
+ >>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0 and
+ >>> # grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -124,6 +130,7 @@ class Adagrad(Optimizer):
 accum = self.accum
 grads = self.decay_weight(grads)
 grads = self.scale_grad(grads)
+ grads = self.gradients_centralization(grads)
 lr = self.get_lr()
 if self.is_group_lr:
 success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
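For context on what the new `gradients_centralization` step computes: gradient centralization (Yong et al., 2020) subtracts each conv filter's mean from its gradient before the update. A minimal NumPy sketch of the idea; the axis convention follows the GC paper and is an assumption here, since the patch's `inner.Centralization` operator is not documented in this diff:

```python
import numpy as np

def grad_centralization(grad):
    """Sketch of gradient centralization for a conv weight gradient.

    Subtracts the mean taken over every axis except the output-channel
    axis (axis 0), so each filter's gradient has zero mean. The real
    operator's axis handling may differ.
    """
    if grad.ndim <= 1:
        return grad  # biases / 1-D parameters are left untouched
    axes = tuple(range(1, grad.ndim))
    return grad - grad.mean(axis=axes, keepdims=True)

g = np.random.randn(8, 3, 3, 3)   # conv weight grad: (out, in, kh, kw)
gc = grad_centralization(g)
print(np.allclose(gc.reshape(8, -1).mean(axis=1), 0.0))   # True
```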
@@ -235,6 +235,10 @@ class Adam(Optimizer):
 the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
 which in the 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@@ -275,12 +279,14 @@ class Adam(Optimizer):
 >>> #2) Use parameter groups and set different values
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params, 'lr': 0.01},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
- >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
- >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
+ >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
+ >>> # grad centralization of True.
+ >>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0 and
+ >>> # grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -320,6 +326,7 @@ class Adam(Optimizer):
 gradients = self.decay_weight(gradients)
 gradients = self.scale_grad(gradients)
 gradients = self._grad_sparse_indices_deduplicate(gradients)
+ gradients = self.gradients_centralization(gradients)
 lr = self.get_lr()
 beta1_power = self.beta1_power * self.beta1
@@ -97,7 +97,7 @@ class FTRL(Optimizer):
 \end{cases}\\
 \end{array}
- :math:`m` represents `accum`, :math:`g` represents `grads`, :math:`t` represents updateing step,
+ :math:`m` represents `accum`, :math:`g` represents `grads`, :math:`t` represents updating step,
 :math:`u` represents `linear`, :math:`p` represents `lr_power`, :math:`\alpha` represents `learning_rate`,
 :math:`\omega` represents `params`.
@@ -128,6 +128,10 @@ class FTRL(Optimizer):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 initial_accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
 learning_rate (float): The learning rate value, must be zero or positive, dynamic learning rate is currently
 not supported. Default: 0.001.
@@ -157,12 +161,13 @@ class FTRL(Optimizer):
 >>> #2) Use parameter groups and set different values
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.FTRL(group_params, learning_rate=0.1, weight_decay=0.0)
- >>> # The conv_params's parameters will use weight decay of 0.01.
- >>> # The no_conv_params's parameters will use default weight decay of 0.0.
+ >>> # The conv_params's parameters will use the learning rate of 0.1, a weight decay of 0.01 and grad
+ >>> # centralization of True.
+ >>> # The no_conv_params's parameters will use the default weight decay of 0.0 and grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -201,6 +206,7 @@ class FTRL(Optimizer):
 grads = self.decay_weight(grads)
 grads = self.scale_grad(grads)
 grads = self._grad_sparse_indices_deduplicate(grads)
+ grads = self.gradients_centralization(grads)
 lr = self.get_lr()
 success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
@@ -200,6 +200,10 @@ class Lamb(Optimizer):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@@ -234,13 +238,14 @@ class Lamb(Optimizer):
 ... decay_steps=4, power=0.5)
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params, 'lr': poly_decay_lr},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0)
- >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
+ >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
+ >>> # grad centralization of True.
 >>> # The no_conv_params's parameters will use dynamic learning rate of poly decay learning rate and default
- >>> # weight decay of 0.0.
+ >>> # weight decay of 0.0 and grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -268,6 +273,7 @@ class Lamb(Optimizer):
 def construct(self, gradients):
 lr = self.get_lr()
 lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
+ gradients = self.gradients_centralization(gradients)
 if self.is_group:
 if self.is_group_lr:
 optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
@@ -154,6 +154,10 @@ class LazyAdam(Optimizer):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@@ -195,12 +199,14 @@ class LazyAdam(Optimizer):
 >>> #2) Use parameter groups and set different values
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params, 'lr': 0.01},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
- >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
- >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
+ >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
+ >>> # grad centralization of True.
+ >>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0 and
+ >>> # grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -237,6 +243,7 @@ class LazyAdam(Optimizer):
 gradients = self.decay_weight(gradients)
 gradients = self.scale_grad(gradients)
 gradients = self._grad_sparse_indices_deduplicate(gradients)
+ gradients = self.gradients_centralization(gradients)
 lr = self.get_lr()
 self.beta1_power = self.beta1_power * self.beta1
@@ -83,6 +83,10 @@ class Momentum(Optimizer):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@@ -117,12 +121,14 @@ class Momentum(Optimizer):
 >>> #2) Use parameter groups and set different values
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params, 'lr': 0.01},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
- >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01.
- >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0.
+ >>> # The conv_params's parameters will use a learning rate of default value 0.1, a weight decay of 0.01 and
+ >>> # grad centralization of True.
+ >>> # The no_conv_params's parameters will use a learning rate of 0.01, a weight decay of default value 0.0
+ >>> # and grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -145,6 +151,7 @@ class Momentum(Optimizer):
 moments = self.moments
 gradients = self.decay_weight(gradients)
 gradients = self.scale_grad(gradients)
+ gradients = self.gradients_centralization(gradients)
 lr = self.get_lr()
 if self.is_group_lr:
 success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
@@ -19,6 +19,7 @@ import numpy as np
 import mindspore
 from mindspore.ops import functional as F, composite as C, operations as P
+ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.container import CellList
 from mindspore.common.parameter import Parameter, ParameterTuple
@@ -43,12 +44,16 @@ class Optimizer(Cell):
 This class defines the API to add Ops to train a model. Never use
 this class directly, but instead instantiate one of its subclasses.
- Different parameter groups can set different `learning_rate` and `weight_decay`.
+ Different parameter groups can set different `learning_rate`, `weight_decay` and `grad_centralization`.
 When separating parameter groups, the weight decay in each group will be applied on the parameters if the
 weight_decay is positive. For most optimizer, when not separating parameters, the `weight_decay` in the API will
 be applied on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
+ When separating parameter groups, if you want to centralize the gradient, set `grad_centralization` to True,
+ but gradient centralization can only be applied to the parameters of the convolution layer. If the parameters
+ of a non-convolution layer are set to True, an error will be reported. Default: False.
 To improve parameter groups performance, the customized order of parameters can be supported.
 Args:
@@ -75,6 +80,9 @@ class Optimizer(Cell):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used.
 weight_decay (float): A floating point value for the weight decay. It must be equal to or greater than 0.
 If the type of `weight_decay` input is int, it will be converted to float. Default: 0.0.
 loss_scale (float): A floating point value for the loss scale. It must be greater than 0. If the
@@ -106,6 +114,7 @@ class Optimizer(Cell):
 self.loss_scale = loss_scale
 weight_decay = self._preprocess_weight_decay(weight_decay)
+ self.grad_centralization = False
 self._unique = True
 self._target = context.get_context("device_target")
@@ -121,7 +130,8 @@ class Optimizer(Cell):
 self.group_params = []
 self.group_lr = []
 self.group_weight_decay = []
- self._init_group_params(parameters, learning_rate, weight_decay)
+ self.group_grad_centralization = []
+ self._init_group_params(parameters, learning_rate, weight_decay, self.grad_centralization)
 # The final value of dynamic_lr can be determined after the process of parse_single_lr and init_group_params
 if self.dynamic_lr:
@@ -129,12 +139,10 @@ class Optimizer(Cell):
 self.global_step = Parameter(initializer(0, [1], mindspore.int32), name='global_step')
 if self.is_group_lr:
- if self.dynamic_lr:
- self.learning_rate = CellList(self.group_lr)
- else:
- self.learning_rate = ParameterTuple(self.group_lr)
+ self.learning_rate = CellList(self.group_lr) if self.dynamic_lr else ParameterTuple(self.group_lr)
 else:
 self.learning_rate = self._build_single_lr(learning_rate, 'learning_rate')
 if self.is_group:
 self.parameters = ParameterTuple(self.group_params)
 self.weight_decay = tuple(self.group_weight_decay)
@@ -142,6 +150,7 @@ class Optimizer(Cell):
 decay_filter = lambda x: x > 0
 self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
 self.exec_weight_decay = any(self.decay_flags)
+ self.grad_centralization_flags = tuple(self.group_grad_centralization)
 else:
 self.parameters = ParameterTuple(parameters)
 self.weight_decay = weight_decay * loss_scale
@@ -163,6 +172,10 @@ class Optimizer(Cell):
 self.global_step_increase_tensor = Tensor(1, mstype.int32)
 self.param_length = len(self.parameters)
 self.map_ = C.Map()
+ self._use_parallel_optimizer()
+ def _use_parallel_optimizer(self):
+ """Indicates whether to use automatic parallelism."""
 if context.get_auto_parallel_context("enable_parallel_optimizer"):
 if _get_parallel_mode() == ParallelMode.DATA_PARALLEL and context.get_context("device_target") == "Ascend":
 self.use_parallel = True
@@ -187,7 +200,6 @@ class Optimizer(Cell):
 self.param_names = []
 for param in self.parameters:
 self.param_names.append(param.name)
 else:
 self.optim_filter = (True,) * self.param_length
@@ -239,6 +251,25 @@ class Optimizer(Cell):
 return gradients
+ def gradients_centralization(self, gradients):
+ """
+ Gradients centralization.
+ A method for optimizing convolutional layer parameters to improve the training speed of a deep learning
+ neural network model.
+ Args:
+ gradients (tuple[Tensor]): The gradients of `self.parameters`, and have the same shape as
+ `self.parameters`.
+ Returns:
+ tuple[Tensor], The gradients after gradients centralization.
+ """
+ if self.is_group:
+ gradients = self.map_(F.partial(_apply_grad_centralization), self.grad_centralization_flags, gradients)
+ return gradients
 def scale_grad(self, gradients):
 """
 Loss scale for mixed precision.
@@ -273,6 +304,11 @@ class Optimizer(Cell):
 return weight_decay
 raise TypeError("Weight decay should be int or float.")
+ def _preprocess_grad_centralization(self, grad_centralization):
+ if not isinstance(grad_centralization, bool):
+ raise TypeError("The gradients centralization should be bool")
+ return grad_centralization
 def _preprocess_single_lr(self, learning_rate):
 """Check lr value, and convert lr to a float, a Tensor or a LearningRateSchedule."""
 if isinstance(learning_rate, (float, int)):
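The new check above is strict: only a real Python bool passes, so truthy values like `1` are rejected. A standalone mirror of that contract, kept free of MindSpore imports so it runs anywhere:

```python
def _preprocess_grad_centralization(grad_centralization):
    # Mirror of the check added above: only a real bool passes.
    if not isinstance(grad_centralization, bool):
        raise TypeError("The gradients centralization should be bool")
    return grad_centralization

print(_preprocess_grad_centralization(True))   # True
try:
    _preprocess_grad_centralization(1)         # truthy, but not bool
except TypeError as exc:
    print(exc)                                 # The gradients centralization should be bool
```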
@@ -315,7 +351,7 @@ class Optimizer(Cell):
 def _check_group_params(self, parameters):
 """Check group params."""
- parse_keys = ['params', 'lr', 'weight_decay', 'order_params']
+ parse_keys = ['params', 'lr', 'weight_decay', 'order_params', 'grad_centralization']
 for group_param in parameters:
 invalid_key = list(filter(lambda x: x not in parse_keys, group_param.keys()))
 if invalid_key:
@@ -365,8 +401,8 @@ class Optimizer(Cell):
 elif group_lr_length != tensor_lr_length:
 raise ValueError("The Tensor type dynamic learning rate in group should be the same size.")
- def _init_group_params(self, parameters, learning_rate, weight_decay):
- """Initialize learning rate or weight decay in group params."""
+ def _init_group_params(self, parameters, learning_rate, weight_decay, grad_centralization):
+ """Initialize learning rate, weight decay or grad centralization in group params."""
 self._parse_group_params(parameters, learning_rate)
 default_lr = self._build_single_lr(learning_rate, 'learning_rate')
@@ -391,8 +427,20 @@ class Optimizer(Cell):
 else:
 weight_decay_ = weight_decay * self.loss_scale
+ if 'grad_centralization' in group_param.keys():
+ self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
+ for param in group_param['params']:
+ validator.check_value_type("parameter", param, [Parameter], self.cls_name)
+ if "conv" not in param.name and self.grad_centralization is True:
+ raise ValueError("Grad centralization can be performed only on the conv layer. If the parameter "
+ "is not a convolution layer, this parameter cannot be set to True.")
+ grad_centralization_ = self.grad_centralization
+ else:
+ grad_centralization_ = grad_centralization
 for key in group_param.keys():
- if key not in ('params', 'lr', 'weight_decay'):
+ if key not in ('params', 'lr', 'weight_decay', 'grad_centralization'):
 logger.warning(f"The optimizer cannot parse '{key}' when setting parameter groups.")
 for param in group_param['params']:
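Note that the guard above is name-based: a group may set `grad_centralization: True` only if every parameter name in it contains "conv". A standalone sketch of that behaviour, with the parameter names supplied as plain strings rather than a real network:

```python
def check_grad_centralization_group(param_names, grad_centralization):
    # Standalone mirror of the name-based guard added above.
    for name in param_names:
        if "conv" not in name and grad_centralization:
            raise ValueError("Grad centralization can be performed only on the conv layer.")

check_grad_centralization_group(["conv1.weight", "conv2.weight"], True)  # passes silently
try:
    check_grad_centralization_group(["fc.weight", "fc.bias"], True)      # 'fc' names lack 'conv'
except ValueError as exc:
    print(exc)
```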
@@ -403,13 +451,14 @@ class Optimizer(Cell):
 params_store.append(param.name)
 self.group_lr.append(lr)
 self.group_weight_decay.append(weight_decay_)
+ self.group_grad_centralization.append(grad_centralization_)
 if self.is_group_params_ordered:
 self._order_and_adjust_group_params(ordered_parameters)
 def _order_and_adjust_group_params(self, ordered_parameters):
 """
- Order group parameter, learning rate and weight decay in group params.
+ Order group parameter, learning rate, weight decay and grad centralization in group params.
 """
 params_length = len(self.group_params)
 if len(ordered_parameters) != len(self.group_params):
@@ -418,17 +467,21 @@ class Optimizer(Cell):
 ordered_params = [None] * params_length
 ordered_learning_rate = [None] * params_length
 ordered_weight_decay = [None] * params_length
+ ordered_grad_centralization = [None] * params_length
 params_name = [param.name for param in ordered_parameters]
- for param, lr, wd in zip(self.group_params, self.group_lr, self.group_weight_decay):
+ for param, lr, wd, gc in zip(self.group_params, self.group_lr, self.group_weight_decay,
+ self.group_grad_centralization):
 index = params_name.index(param.name)
 ordered_params[index] = param
 ordered_learning_rate[index] = lr
 ordered_weight_decay[index] = wd
+ ordered_grad_centralization[index] = gc
 self.group_params = ordered_params
 self.group_lr = ordered_learning_rate
 self.group_weight_decay = ordered_weight_decay
+ self.group_grad_centralization = ordered_grad_centralization
 def get_lr(self):
 """
@@ -535,8 +588,10 @@
 op_add = P.AddN()
 op_gather = P.Gather()
 op_mul = P.Mul()
+ op_gc = inner.Centralization()
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
+ _apply_grad_centralization = C.MultitypeFuncGraph("apply_grad_centralization")
 @_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
@@ -558,9 +613,18 @@ def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
 return gradient
+ @_apply_grad_centralization.register("Bool", "Tensor")
+ def _tensor_apply_grad_centralization(if_apply, gradient):
+ """Get grad with grad_centralization."""
+ if if_apply:
+ return op_gc(gradient, -1)
+ return gradient
 _grad_scale = C.MultitypeFuncGraph("grad_scale")
 _indices_deduplicate = C.MultitypeFuncGraph("indices_deduplicate")
 @_grad_scale.register("Number", "Tensor")
 def tensor_grad_scale(scale, grad):
 """Get grad with scale."""
@@ -568,11 +632,13 @@ def tensor_grad_scale(scale, grad):
 return grad
 return op_mul(grad, F.cast(scale, F.dtype(grad)))
 @_grad_scale.register("Tensor", "Tensor")
 def tensor_grad_scale_with_tensor(scale, grad):
 """Get grad with scale."""
 return op_mul(grad, F.cast(scale, F.dtype(grad)))
 @_grad_scale.register("Tensor", "RowTensor")
 def tensor_grad_scale_with_sparse(scale, grad):
 """Get grad with scale."""
@@ -85,6 +85,10 @@ class ProximalAdagrad(Optimizer):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 accum (float): The starting value for accumulators, must be zero or positive values. Default: 0.1.
 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
@@ -118,12 +122,14 @@ class ProximalAdagrad(Optimizer):
 >>> #2) Use parameter groups and set different values
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params, 'lr': 0.01},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.ProximalAdagrad(group_params, learning_rate=0.1, weight_decay=0.0)
- >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
- >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
+ >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
+ >>> # grad centralization of True.
+ >>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0 and
+ >>> # grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -148,6 +154,7 @@ class ProximalAdagrad(Optimizer):
 grads = self.decay_weight(grads)
 grads = self.scale_grad(grads)
 grads = self._grad_sparse_indices_deduplicate(grads)
+ grads = self.gradients_centralization(grads)
 lr = self.get_lr()
 if self.is_group_lr:
 success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
@@ -105,6 +105,10 @@ class RMSProp(Optimizer):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@@ -141,12 +145,14 @@ class RMSProp(Optimizer):
 >>> #2) Use parameter groups and set different values
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params, 'lr': 0.01},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.RMSProp(group_params, learning_rate=0.1, weight_decay=0.0)
- >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01.
- >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0.
+ >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
+ >>> # grad centralization of True.
+ >>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0 and
+ >>> # grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -182,6 +188,7 @@ class RMSProp(Optimizer):
 params = self.parameters
 gradients = self.decay_weight(gradients)
 gradients = self.scale_grad(gradients)
+ gradients = self.gradients_centralization(gradients)
 lr = self.get_lr()
 if self.centered:
 if self.is_group_lr:
@@ -80,6 +80,10 @@ class SGD(Optimizer):
 the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
 in the value of 'order_params' must be in one of group parameters.
+ - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
+ If not, the `grad_centralization` in the base class will be used. This parameter only works on the
+ convolution layer.
 learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
 When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
 the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
@@ -116,12 +120,14 @@ class SGD(Optimizer):
 >>> #2) Use parameter groups and set different values
 >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
 >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
- >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
+ >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
 ... {'params': no_conv_params, 'lr': 0.01},
 ... {'order_params': net.trainable_params()}]
 >>> optim = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0)
- >>> # The conv_params's parameters will use a learning rate of default value 0.1 and a weight decay of 0.01.
- >>> # The no_conv_params's parameters will use a learning rate of 0.01 and a weight decay of default value 0.0.
+ >>> # The conv_params's parameters will use the default learning rate of 0.1, a weight decay of 0.01 and
+ >>> # grad centralization of True.
+ >>> # The no_conv_params's parameters will use a learning rate of 0.01, the default weight decay of 0.0 and
+ >>> # grad centralization of False.
 >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
 >>>
 >>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -171,6 +177,7 @@ class SGD(Optimizer):
 accum = self.accum
 stat = self.stat
 gradients = self.scale_grad(gradients)
+ gradients = self.gradients_centralization(gradients)
 lr = self.get_lr()
 if self.is_group_lr:
 success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
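Taken together, the per-optimizer changes are mechanical: one `gradients_centralization` call in each `construct`, gated per parameter by `grad_centralization_flags`. A hedged end-to-end sketch of the intended usage, mirroring the docstring examples; `TinyNet`, its layer sizes, and the 32x32 input assumption are illustrative, not part of this patch:

```python
import mindspore.nn as nn
from mindspore import Model

class TinyNet(nn.Cell):
    """Minimal conv + dense net; attribute names yield 'conv.*' and 'fc.*' params."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(8 * 32 * 32, 10)   # assumes 32x32 inputs, 'same' padding

    def construct(self, x):
        return self.fc(self.flatten(self.conv(x)))

net = TinyNet()
conv_params = [p for p in net.trainable_params() if 'conv' in p.name]
other_params = [p for p in net.trainable_params() if 'conv' not in p.name]
group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
                {'params': other_params, 'lr': 0.01},
                {'order_params': net.trainable_params()}]
optim = nn.SGD(group_params, learning_rate=0.1, weight_decay=0.0)
model = Model(net, loss_fn=nn.SoftmaxCrossEntropyWithLogits(), optimizer=optim)
# model.train(epoch, dataset) would then centralize only the conv gradients,
# leaving the dense-layer gradients untouched.
```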