lazyadam.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lazy adam"""
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer

_lazy_adam_opt = C.MultitypeFuncGraph("lazy_adam_opt")
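
# The handlers registered on _lazy_adam_opt below are selected by argument type at graph compile time:
# a RowTensor gradient dispatches to the sparse handler (_run_opt_with_sparse), while a dense Tensor
# gradient dispatches to _run_opt_with_one_number.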
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, beta2_power,
                         beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter):
    """Apply sparse lazy adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values

    if ps_parameter:
        op_shape = P.Shape()
        shapes = (op_shape(params), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), params))
        return success

    if not target:
        success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_gather = P.GatherV2()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)
        scatter_update = P.ScatterUpdate(use_locking)

        m_slice = op_gather(m, indices, 0)
        v_slice = op_gather(v, indices, 0)

        next_m = m_slice * beta1 + values * (1 - beta1)
        next_v = v_slice * beta2 + values * values * (1 - beta2)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        if use_nesterov:
            m_temp = beta1 * next_m + values * (1 - beta1)
            param_update = m_temp / (op_sqrt(next_v) + eps)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        success = F.depend(success, scatter_add(params, indices, - lr_t * param_update))
        success = F.depend(success, scatter_update(m, indices, next_m))
        success = F.depend(success, scatter_update(v, indices, next_v))

    return success
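
# Dense case: the registration below handles ordinary Tensor gradients, delegating the update to the fused
# P.Adam kernel, or to Push/Pull when the parameter lives on the parameter server.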
@_lazy_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                         "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                             beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter):
    """Apply lazy adam optimizer to the weight parameter using Tensor."""
    success = True
    if ps_parameter:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
    else:
        success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success

def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_value_type("weight_decay", weight_decay, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
    validator.check_non_negative_float(weight_decay, "weight_decay", prim_name)

class LazyAdam(Optimizer):
    r"""
    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`l` represents the scaling factor `lr`, :math:`\beta_1, \beta_2` represent
    `beta1` and `beta2`, :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent
    `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`, and
    :math:`\epsilon` represents `eps`.

    Note:
        When separating parameter groups, the weight decay in each group will be applied on the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

        The sparse strategy is applied when the SparseGatherV2 operator is used in the forward network. Note that the
        sparse behavior is not equivalent to the original Adam algorithm, as only the parameters at the current
        indices are updated. The sparse feature is under continuous development. To execute the sparse strategy on
        the host, set the target to "CPU" (a usage sketch is given at the end of this file).

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. There should be no other keys in this `dict`, and the
              parameters listed in the value of 'order_params' must be in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a one-dimensional Tensor, the learning rate is dynamic and
            the i-th step takes the i-th value as the learning rate. When the learning_rate is a
            LearningRateSchedule, the learning rate is also dynamic and the i-th learning rate is calculated during
            training according to the formula of LearningRateSchedule. When the learning_rate is a float or a
            zero-dimensional Tensor, the learning rate is fixed. Other cases are not supported. The float learning
            rate must be equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to
            float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
            Default: 1e-8.
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the var, m, and v tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If true, update the gradients using NAG.
            If false, update the gradients without using NAG. Default: False.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1.
            Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.LazyAdam(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.LazyAdam(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params group will use the default learning rate of 0.1 and a weight decay of 0.01.
        >>> # The no_conv_params group will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The order of parameters followed by the optimizer is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """

    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
        super(LazyAdam, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = Tensor(eps, mstype.float32)
        self.use_nesterov = use_nesterov
        self.use_locking = use_locking
        self._is_device = True
        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')

        self.hyper_map = C.HyperMap()
        self.opt = P.Adam(use_locking, use_nesterov)
        self.sparse_opt = P.FusedSparseLazyAdam(use_locking, use_nesterov)
        self.sparse_opt.add_prim_attr("primitive_target", "CPU")
        self._ps_pull = P.Pull()
        self._ps_push = P.Push("Adam", [0, 1, 2])
        self._ps_push.add_prim_attr("use_nesterov", use_nesterov)

    def construct(self, gradients):
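        """Apply the lazy Adam update to `gradients` and return the success flag."""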
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        lr = self.get_lr()

        self.beta1_power = self.beta1_power * self.beta1
        self.beta2_power = self.beta2_power * self.beta2

        if self.is_group_lr:
            success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          self.use_locking, self.use_nesterov, self._is_device,
                                          self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps),
                                lr, gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
        else:
            success = self.map_(F.partial(_lazy_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          self.use_locking, self.use_nesterov, self._is_device,
                                          self.beta1_power, self.beta2_power, self.beta1, self.beta2, self.eps, lr),
                                gradients, self.parameters, self.moment1, self.moment2, self.ps_parameters)
        return success

    @Optimizer.target.setter
    def target(self, value):
        """If the input value is set to "CPU", the parameters will be updated on the host using the Fused
           optimizer operation."""
        if value not in ('CPU', 'Ascend'):
            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))

        self._is_device = (value != 'CPU')
        self._target = value
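
# Usage sketch: when the forward network uses P.SparseGatherV2, gradients arrive as RowTensor and the sparse
# branch of _lazy_adam_opt is taken; setting `target` to "CPU" then runs the fused sparse kernel on the host.
# `net` and `loss` below are placeholders, as in the docstring examples.
#
#     optim = nn.LazyAdam(net.trainable_params(), learning_rate=0.01)
#     optim.target = "CPU"
#     model = Model(net, loss_fn=loss, optimizer=optim)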