You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

lamb.py 16 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """lamb"""
  16. import numpy as np
  17. from mindspore import context
  18. from mindspore.common import dtype as mstype
  19. from mindspore.common.initializer import initializer
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import composite as C
  22. from mindspore.ops import functional as F
  23. from mindspore.common.parameter import Parameter
  24. from mindspore.common.tensor import Tensor
  25. from mindspore._checkparam import Validator as validator
  26. from mindspore._checkparam import Rel
  27. from .optimizer import Optimizer
  28. from .. import layer
  29. num_one = Tensor(np.ones([1]), mstype.float32)
  30. _lamb_opt = C.MultitypeFuncGraph("lamb_opt")
  31. @_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
  32. "Tensor", "Bool", "Bool")
  33. def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
  34. """
  35. Update parameters.
  36. Args:
  37. beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
  38. beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
  39. eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
  40. lr (Tensor): Learning rate.
  41. weight_decay (Number): Weight decay. Should be equal to or greater than 0.
  42. global_step (Tensor): Global step.
  43. param (Tensor): Parameters.
  44. m (Tensor): m value of parameters.
  45. v (Tensor): v value of parameters.
  46. gradient (Tensor): Gradient of parameters.
  47. decay_flag (bool): Specifies whether param update with weight decay.
  48. optim_filter(bool): Applies parameter update or not.
  49. Returns:
  50. Tensor, the new value of v after updating.
  51. """
  52. if optim_filter:
  53. op_mul = P.Mul()
  54. op_sqrt = P.Sqrt()
  55. op_rsqrt = P.Rsqrt()
  56. op_square = P.Square()
  57. op_cast = P.Cast()
  58. op_reshape = P.Reshape()
  59. op_shape = P.Shape()
  60. op_pow = P.Pow()
  61. op_norm = layer.Norm()
  62. op_select = P.Select()
  63. op_greater = P.Greater()
  64. op_fill = P.Fill()
  65. op_dtype = P.DType()
  66. param_fp32 = op_cast(param, mstype.float32)
  67. m_fp32 = op_cast(m, mstype.float32)
  68. v_fp32 = op_cast(v, mstype.float32)
  69. gradient_fp32 = op_cast(gradient, mstype.float32)
  70. next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
  71. next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
  72. next_mm = next_m / (op_cast(num_one, mstype.float32)
  73. - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
  74. next_vv = next_v / (op_cast(num_one, mstype.float32) -
  75. op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
  76. w_norm = op_norm(param_fp32)
  77. g_norm = op_norm(gradient_fp32)
  78. g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
  79. zeros = F.zeros_like(w_norm)
  80. ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
  81. trust_ratio = op_select(
  82. op_greater(w_norm, zeros),
  83. op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
  84. ones)
  85. tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
  86. trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
  87. update = next_mm / (op_sqrt(next_vv) + eps)
  88. if decay_flag:
  89. update = update + op_mul(weight_decay, param_fp32)
  90. update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
  91. next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
  92. next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
  93. next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
  94. next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
  95. return op_cast(next_param, F.dtype(param))
  96. return gradient
  97. _lamb_opt_ascend = C.MultitypeFuncGraph("lamb_opt_ascend")
  98. @_lamb_opt_ascend.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
  99. "Tensor", "Bool", "Bool")
  100. def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag,
  101. optim_filter):
  102. """
  103. Update parameters function when device target is ascend.
  104. Args:
  105. beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
  106. beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
  107. eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
  108. lr (Tensor): Learning rate.
  109. weight_decay (Number): Weight decay. Should be equal to or greater than 0.
  110. global_step (Tensor): Global step.
  111. param (Tensor): Parameters.
  112. m (Tensor): m value of parameters.
  113. v (Tensor): v value of parameters.
  114. gradient (Tensor): Gradient of parameters.
  115. decay_flag (bool): Specifies whether param update with weight decay.
  116. optim_filter(bool): Applies parameter update or not.
  117. Returns:
  118. Tensor, the new value of v after updating.
  119. """
  120. if optim_filter:
  121. op_cast = P.Cast()
  122. op_norm = layer.Norm()
  123. op_lamb_apply_optimizer_assign = P.LambApplyOptimizerAssign()
  124. op_lamb_apply_weight_assign = P.LambApplyWeightAssign()
  125. param_fp32 = op_cast(param, mstype.float32)
  126. gradient_fp32 = op_cast(gradient, mstype.float32)
  127. new_global_step = op_cast(global_step + num_one, mstype.float32)
  128. weight_decay_flag = op_cast(decay_flag, mstype.float32)
  129. update, _, _ = op_lamb_apply_optimizer_assign(gradient_fp32, v, m, param_fp32,
  130. beta1, 1.0 - beta1, beta2, 1.0 - beta2, eps,
  131. new_global_step, weight_decay_flag, weight_decay)
  132. w_norm = op_norm(param_fp32)
  133. g_norm = op_norm(update)
  134. update = F.depend(update, op_lamb_apply_weight_assign(w_norm, g_norm, lr, update, param))
  135. return update
  136. return gradient
  137. def _check_param_value(beta1, beta2, eps, prim_name):
  138. validator.check_value_type("beta1", beta1, [float], prim_name)
  139. validator.check_value_type("beta2", beta2, [float], prim_name)
  140. validator.check_value_type("eps", eps, [float], prim_name)
  141. validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
  142. validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
  143. validator.check_positive_float(eps, "eps", prim_name)
  144. class Lamb(Optimizer):
  145. """
  146. Lamb Dynamic Learning Rate.
  147. LAMB is an optimization algorithm employing a layerwise adaptive large batch
  148. optimization technique. Refer to the paper `LARGE BATCH OPTIMIZATION FOR DEEP LEARNING: TRAINING BERT IN 76
  149. MINUTES <https://arxiv.org/abs/1904.00962>`_.
  150. Note:
  151. When separating parameter groups, the weight decay in each group will be applied on the parameters if the
  152. weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
  153. on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
  154. To improve parameter groups performance, the customized order of parameters can be supported.
  155. Args:
  156. params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
  157. the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
  158. "lr", "weight_decay" and "order_params" are the keys can be parsed.
  159. - params: Required. The value must be a list of `Parameter`.
  160. - lr: Optional. If "lr" in the keys, the value of corresponding learning rate will be used.
  161. If not, the `learning_rate` in the API will be used.
  162. - weight_decay: Optional. If "weight_decay" in the keys, the value of corresponding weight decay
  163. will be used. If not, the `weight_decay` in the API will be used.
  164. - order_params: Optional. If "order_params" in the keys, the value must be the order of parameters and
  165. the order will be followed in optimizer. There are no other keys in the `dict` and the parameters which
  166. in the value of 'order_params' must be in one of group parameters.
  167. - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
  168. If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
  169. learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
  170. When the learning_rate is an Iterable or a Tensor in a 1D dimension, use dynamic learning rate, then
  171. the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
  172. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
  173. according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
  174. dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
  175. equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
  176. beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
  177. Should be in range (0.0, 1.0).
  178. beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
  179. Should be in range (0.0, 1.0).
  180. eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
  181. Should be greater than 0.
  182. weight_decay (float): Weight decay (L2 penalty). Default: 0.0. Should be equal to or greater than 0.
  183. Inputs:
  184. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
  185. Outputs:
  186. tuple[bool], all elements are True.
  187. Raises:
  188. TypeError: If `learning_rate` is not one of int, float, Tensor, Iterable, LearningRateSchedule.
  189. TypeError: If element of `parameters` is neither Parameter nor dict.
  190. TypeError: If `beta1`, `beta2` or `eps` is not a float.
  191. TypeError: If `weight_decay` is neither float nor int.
  192. ValueError: If `eps` is less than or equal to 0.
  193. ValueError: If `beta1`, `beta2` is not in range (0.0, 1.0).
  194. ValueError: If `weight_decay` is less than 0.
  195. Supported Platforms:
  196. ``Ascend`` ``GPU``
  197. Examples:
  198. >>> net = Net()
  199. >>> #1) All parameters use the same learning rate and weight decay
  200. >>> optim = nn.Lamb(params=net.trainable_params(), learning_rate=0.1)
  201. >>>
  202. >>> #2) Use parameter groups and set different values
  203. >>> poly_decay_lr = learning_rate_schedule.PolynomialDecayLR(learning_rate=0.1, end_learning_rate=0.01,
  204. ... decay_steps=4, power = 0.5)
  205. >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
  206. >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
  207. >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
  208. ... {'params': no_conv_params, 'lr': poly_decay_lr},
  209. ... {'order_params': net.trainable_params(0.01)}]
  210. >>> optim = nn.Lamb(group_params, learning_rate=0.1, weight_decay=0.0)
  211. >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
  212. >>> # centralization of True.
  213. >>> # The no_conv_params's parameters will use dynamic learning rate of poly decay learning rate and default
  214. >>> # weight decay of 0.0 and grad centralization of False.
  215. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
  216. >>>
  217. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  218. >>> model = Model(net, loss_fn=loss, optimizer=optim)
  219. """
  220. def __init__(self, params, learning_rate, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
  221. super(Lamb, self).__init__(learning_rate, params, weight_decay)
  222. _check_param_value(beta1, beta2, eps, self.cls_name)
  223. # turn them to scalar when me support scalar/tensor mix operations
  224. self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
  225. self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
  226. self.eps = Tensor(np.array([eps]).astype(np.float32))
  227. self.params = self.parameters
  228. self.moments1 = self.params.clone(prefix="lamb_m", init='zeros')
  229. self.moments2 = self.params.clone(prefix="lamb_v", init='zeros')
  230. if not self.dynamic_lr:
  231. self.global_step = Parameter(initializer(0, [1]), name='global_step')
  232. self.assignadd = P.AssignAdd()
  233. self.hyper_map = C.HyperMap()
  234. self.device_ascend = context.get_context("device_target") == "Ascend"
  235. def construct(self, gradients):
  236. lr = self.get_lr()
  237. lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt
  238. gradients = self.gradients_centralization(gradients)
  239. if self.is_group:
  240. if self.is_group_lr:
  241. optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
  242. self.global_step),
  243. lr, self.weight_decay, self.params, self.moments1, self.moments2,
  244. gradients, self.decay_flags, self.optim_filter)
  245. else:
  246. optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
  247. self.global_step, lr),
  248. self.weight_decay, self.params, self.moments1, self.moments2,
  249. gradients, self.decay_flags, self.optim_filter)
  250. else:
  251. optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps,
  252. self.global_step, lr, self.weight_decay),
  253. self.params, self.moments1, self.moments2, gradients,
  254. self.decay_flags, self.optim_filter)
  255. if self.use_parallel:
  256. optim_result = F.depend(optim_result, self.broadcast_params(optim_result))
  257. if not self.dynamic_lr:
  258. optim_result = F.depend(optim_result, self.assignadd(self.global_step, 1))
  259. return optim_result