|
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """adam"""
- import numpy as np
-
- from mindspore.common import dtype as mstype
- from mindspore.common.initializer import initializer
- from mindspore.ops import operations as P
- from mindspore.ops import composite as C
- from mindspore.ops import functional as F
- from mindspore.common.parameter import Parameter
- from mindspore.common.tensor import Tensor
- from mindspore._checkparam import Validator as validator
- from mindspore._checkparam import Rel
- from .optimizer import Optimizer
- from .optimizer import opt_init_args_register
-
# Multitype dispatch graph shared by the per-parameter update functions in this module; the
# overload that runs is chosen by the argument types passed to each `_adam_opt.register(...)`.
_adam_opt = C.MultitypeFuncGraph("adam_opt")
# Constants used only by the Nesterov branch of the sparse (RowTensor) update path.
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
-
-
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Apply one AdamWeightDecay step to a single parameter.

    The arithmetic is done in float32. The new parameter, first moment and second moment are
    cast back to the dtypes of `param`, `m` and `v` respectively and assigned in place. No bias
    correction is applied to the moments.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not. When False, `param`, `m` and `v`
            are left untouched.

    Returns:
        Tensor. When `optim_filter` is True, the updated parameter value cast to the dtype of
        `param`. Otherwise, `gradient` cast to the dtype of `param`.
    """
    op_cast = P.Cast()
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                                - beta1, gradient_fp32)

        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                                - beta2, op_square(gradient_fp32))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            # Decoupled weight decay: added to the update rather than to the gradient.
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        # Reshape back to the parameter's shape, e.g. in case broadcasting against a
        # 1-element `lr` changed the rank of a scalar parameter.
        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

        # Chain the in-place assignments through F.depend so they are kept in the graph
        # and happen before the result is returned.
        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return op_cast(gradient, F.dtype(param))
-
-
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.

    This overload is selected when `gradient` is a RowTensor. It takes one of three paths:

    * Parameter-server parameter without cache: push the hyper-parameters, values and indices,
      then pull the result into `param`.
    * `target` is falsy: run `sparse_opt` (the fused sparse Adam primitive).
    * Otherwise: a composite implementation that decays `m`/`v` densely, scatter-adds the
      gradient rows into them, and applies the bias-corrected step to `param`.

    `Adam.construct` passes `self._is_device` as `target`. NOTE(review): presumably the
    `target` setter clears it for "CPU"; confirm in `Optimizer._set_base_target`.

    Returns:
        The constant `success` flag, made dependent on the update ops.
    """
    success = True
    # `indices`/`values` are used by the PS and fused paths; the composite path below re-reads them.
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        # Decay every row of both moments first, so rows absent from `indices` are still decayed.
        success = F.depend(success, F.assign(m, op_mul(beta1, m)))
        success = F.depend(success, F.assign(v, op_mul(beta2, v)))

        grad_indices = gradient.indices
        grad_value = gradient.values

        # Add the (1 - beta) weighted gradient contributions only at the touched rows.
        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            # NOTE(review): the multiply by `_scaler_ten` here and the divide on the restore
            # below appear intended to snapshot `next_m` as a separate value before `m` is
            # overwritten. Likewise, multiplying the indices by `_scaler_one` appears to force a
            # fresh indices tensor. Neither F.assign below is chained through F.depend, so their
            # ordering relies on the framework's side-effect handling. Confirm before changing.
            m_temp = next_m * _scaler_ten
            F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)
            F.assign(m, m_temp / _scaler_ten)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        # Bias-corrected step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t).
        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
        next_param = param - lr_t * param_update

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
-
-
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
                             beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
                             moment1, moment2, ps_parameter, cache_enable):
    """Apply adam optimizer to the weight parameter using Tensor.

    Dense-gradient overload. A parameter-server parameter without cache is updated by pushing
    the hyper-parameters and gradient, then pulling into `param`. Every other parameter is
    updated in place by the `opt` primitive. `sparse_opt`, `use_locking`, `use_nesterov` and
    `target` are accepted only so the signature matches the sparse overload.

    Returns:
        The constant `True` flag, made dependent on the update op.
    """
    flag = True
    if ps_parameter and not cache_enable:
        shape_of = P.Shape()
        state_shapes = (shape_of(param), shape_of(moment1), shape_of(moment2))
        pushed = push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), state_shapes)
        return F.depend(flag, pull(pushed, param))
    updated = opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    return F.depend(flag, updated)
-
-
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """Apply AdamOffload optimizer to the weight parameter using Tensor.

    `opt` computes the parameter delta from the moments and the gradient. The delta is then
    added to `param` in place.

    Returns:
        The constant `True` flag, made dependent on the in-place add.
    """
    ok = True
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    return F.depend(ok, F.assign_add(param, delta_param))
-
-
def _check_param_value(beta1, beta2, eps, prim_name):
    """Validate the Adam hyper-parameters.

    Checks that `beta1`, `beta2` and `eps` are floats, that both betas lie strictly inside
    (0.0, 1.0), and that `eps` is positive. Errors are raised by `validator` and tagged
    with `prim_name`.
    """
    for arg_name, arg_value in (("beta1", beta1), ("beta2", beta2), ("eps", eps)):
        validator.check_value_type(arg_name, arg_value, [float], prim_name)
    for arg_name, arg_value in (("beta1", beta1), ("beta2", beta2)):
        validator.check_float_range(arg_value, 0.0, 1.0, Rel.INC_NEITHER, arg_name, prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
-
-
class Adam(Optimizer):
    r"""
    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.

    The Adam optimizer can dynamically adjust the learning rate of each parameter using the first-order
    moment estimation and the second-order moment estimation of the gradient.
    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{gather*}
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
        \end{gather*}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`l` represents the bias-corrected learning rate,
    :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`t` represents the updating step while
    :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and `beta2_power`, :math:`\alpha` represents
    `learning_rate`, :math:`w` represents `params`, :math:`\epsilon` represents `eps`.

    Note:
        The sparse strategy is applied when the SparseGatherV2 operator is used in the forward network. If the
        sparse strategy should be executed on the host, set the target to the CPU.
        The sparse feature is under continuous development.

        If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
        'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When
        parameters are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be
        applied.

    Args:
        params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
            `params` is a list of `dict`, the strings "params", "lr", "weight_decay", "grad_centralization" and
            "order_params" are the keys that can be parsed.

            - params: Required. Parameters in current group. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of corresponding learning rate will be used.
              If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of corresponding weight decay
              will be used. If not, the `weight_decay` in the optimizer will be used.

            - grad_centralization: Optional. Must be Boolean. If "grad_centralization" is in the keys, the set value
              will be used. If not, the `grad_centralization` is False by default. This configuration only works on the
              convolution layer.

            - order_params: Optional. When parameters are grouped, this is usually used to maintain the order of
              parameters that appeared in the network to improve performance. The value should be parameters whose
              order will be followed in optimizer.
              If `order_params` is in the keys, other keys will be ignored and the element of 'order_params' must be in
              one group of `params`.

        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: 1e-3.

            - float: The fixed learning rate value. Must be equal to or greater than 0.

            - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

            - Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied.
              For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate.

            - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.

            - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
              LearningRateSchedule with step as the input to get the learning rate of current step.

        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
            1e-8.
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the `w`, `m`, and `v` tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If true, update the gradients using NAG.
            If false, update the gradients without using NAG. Default: False.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_locking` or `use_nesterov` is not a bool.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1`, `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.Adam(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
        >>> # centralization of True.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
        >>> # centralization of False.
        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    @opt_init_args_register
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        # Running products beta^t used for bias correction; advanced once per step in construct.
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = Tensor(eps, mstype.float32)
        self.use_nesterov = use_nesterov
        self.use_locking = use_locking
        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')

        # Passed as `target` to the sparse update: truthy selects the composite scatter-add
        # path, falsy selects the fused sparse primitive.
        self._is_device = True
        self.opt = P.Adam(use_locking, use_nesterov)
        self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
        self.sparse_opt.add_prim_attr("primitive_target", "CPU")
        self._ps_pull = P.Pull()
        self._ps_push = P.Push("Adam", [0, 1, 2])
        self._ps_push.add_prim_attr("use_nesterov", use_nesterov)

    def construct(self, gradients):
        params = self.parameters
        moment1 = self.moment1
        moment2 = self.moment2
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        lr = self.get_lr()

        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power
        if self.is_group_lr:
            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          self.use_locking, self.use_nesterov, self._is_device,
                                          beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
                                lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
        else:
            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          self.use_locking, self.use_nesterov, self._is_device,
                                          beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
                                gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        self._set_base_target(value)
-
-
class AdamWeightDecay(Optimizer):
    r"""
    Implements the Adam algorithm with weight decay applied to the update rather than to the gradient.
    No bias correction is applied to the moment estimates.

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
            update = \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} \\
            update =
            \begin{cases}
                update + weight\_decay * w_{t}
                    & \text{ if } weight\_decay > 0 \\
                update
                    & \text{ otherwise }
            \end{cases} \\
            w_{t+1} = w_{t} - lr * update
        \end{array}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`lr` represents `learning_rate`,
    :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`t` represents the updating step,
    :math:`w` represents `params` and :math:`\epsilon` represents `eps`.

    Note:
        There is usually no connection between an optimizer and mixed precision. But when `FixedLossScaleManager` is
        used and `drop_overflow_update` in `FixedLossScaleManager` is set to False, the optimizer needs to set the
        'loss_scale'. As this optimizer has no `loss_scale` argument, `loss_scale` needs to be processed by other
        means. Refer to the document `LossScale <https://www.mindspore.cn/docs/programming_guide/zh-CN/master/lossscale.html>`_
        to process `loss_scale` correctly.

        If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
        'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When
        parameters are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be
        applied.

    Args:
        params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
            `params` is a list of `dict`, the strings "params", "lr", "weight_decay", and "order_params"
            are the keys that can be parsed.

            - params: Required. Parameters in current group. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of corresponding learning rate will be used.
              If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of corresponding weight decay
              will be used. If not, the `weight_decay` in the optimizer will be used.

            - order_params: Optional. When parameters are grouped, this is usually used to maintain the order of
              parameters that appeared in the network to improve performance. The value should be parameters whose
              order will be followed in optimizer.
              If `order_params` is in the keys, other keys will be ignored and the element of 'order_params' must be in
              one group of `params`.

        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: 1e-3.

            - float: The fixed learning rate value. Must be equal to or greater than 0.

            - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

            - Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied.
              For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate.

            - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.

            - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
              LearningRateSchedule with step as the input to get the learning rate of current step.

        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[Tensor]. One element per parameter: the updated parameter value for parameters selected by
        `optim_filter`, or the gradient cast to the parameter's dtype otherwise.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `beta2` or `eps` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `eps` is less than or equal to 0.
        ValueError: If `beta1`, `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecay, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')

    def construct(self, gradients):
        lr = self.get_lr()
        # Weight decay is applied inside the per-parameter update (via decay_flags), so
        # decay_weight is not called on the gradients here.
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1,
                                              self.moments2, gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
-
-
class AdamOffload(Optimizer):
    r"""
    This optimizer offloads the Adam computation to the host CPU while the parameters are updated on the device,
    to minimize device memory cost. Although this increases performance overhead, the optimizer makes it
    possible to run a larger model.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`l` represents the bias-corrected learning rate,
    :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`t` represents the updating step while
    :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and `beta2_power`, :math:`\alpha` represents
    `learning_rate`, :math:`w` represents `params`, :math:`\epsilon` represents `eps`.

    Note:
        This optimizer only supports `GRAPH_MODE` currently.

        If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
        'beta' or 'gamma' in their names. Users can group parameters to change the strategy of decaying weight. When
        parameters are grouped, each group can set `weight_decay`, if not, the `weight_decay` in optimizer will be
        applied.

    Args:
        params (Union[list[Parameter], list[dict]]): Must be list of `Parameter` or list of `dict`. When the
            `params` is a list of `dict`, the strings "params", "lr", "weight_decay", and "order_params"
            are the keys that can be parsed.

            - params: Required. Parameters in current group. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of corresponding learning rate will be used.
              If not, the `learning_rate` in optimizer will be used. Fixed and dynamic learning rate are supported.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of corresponding weight decay
              will be used. If not, the `weight_decay` in the optimizer will be used.

            - order_params: Optional. When parameters are grouped, this is usually used to maintain the order of
              parameters that appeared in the network to improve performance. The value should be parameters whose
              order will be followed in optimizer.
              If `order_params` is in the keys, other keys will be ignored and the element of 'order_params' must be in
              one group of `params`.

        learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]): Default: 1e-3.

            - float: The fixed learning rate value. Must be equal to or greater than 0.

            - int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

            - Tensor: Its value should be a scalar or a 1-D vector. For scalar, fixed learning rate will be applied.
              For vector, learning rate is dynamic, then the i-th step will take the i-th value as the learning rate.

            - Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.

            - LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of
              LearningRateSchedule with step as the input to get the learning rate of current step.

        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
            1e-8.
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the `w`, `m`, and `v` tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If true, update the gradients using NAG.
            If false, update the gradients without using NAG. Default: False.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. In general, use the
            default value. Only when `FixedLossScaleManager` is used for training and the `drop_overflow_update` in
            `FixedLossScaleManager` is set to False, then this value needs to be the same as the `loss_scale` in
            `FixedLossScaleManager`. Refer to class :class:`mindspore.FixedLossScaleManager` for more details.
            Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Raises:
        TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
        TypeError: If element of `parameters` is neither Parameter nor dict.
        TypeError: If `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
        TypeError: If `weight_decay` is neither float nor int.
        TypeError: If `use_locking` or `use_nesterov` is not a bool.
        ValueError: If `loss_scale` or `eps` is less than or equal to 0.
        ValueError: If `beta1`, `beta2` is not in range (0.0, 1.0).
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.AdamOffload(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.AdamOffload(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
        >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
        >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
        super(AdamOffload, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        # Running products beta^t used for bias correction; advanced once per step in construct.
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = Tensor(eps, mstype.float32)
        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
        # The primitive returns a parameter delta (it does not update the parameter itself) and
        # is pinned to the CPU; the delta is added to the parameter by `_run_off_load_opt`.
        self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
        self.opt.add_prim_attr("primitive_target", "CPU")

    def construct(self, gradients):
        params = self.parameters
        moment1 = self.moment1
        moment2 = self.moment2
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()

        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power
        if self.is_group_lr:
            success = self.map_reverse(F.partial(_adam_opt, self.opt,
                                                 beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
                                       lr, gradients, params, moment1, moment2)
        else:
            success = self.map_reverse(F.partial(_adam_opt, self.opt,
                                                 beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
                                       gradients, params, moment1, moment2)
        return success
|