adam.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""AdamWeightDecayForBert, a customized Adam optimizer for BERT that takes gradients and an overflow flag as inputs."""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.optim.optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, overflow, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        overflow (Tensor): Whether overflow occurred in the current step.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Whether to apply weight decay.
        optim_filter (bool): Whether to apply the parameter update.

    Returns:
        Tensor, the new parameter value after updating.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_select = P.Select()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # Broadcast the scalar overflow flag to the shape of m and cast it to bool. When the flag is set,
        # the Select calls below substitute the previous moment for the gradient contribution and
        # zero out the parameter update.
        cond = op_cast(F.fill(mstype.int32, op_shape(m_fp32), 1) * op_reshape(overflow, (())), mstype.bool_)
        next_m = op_mul(beta1, m_fp32) + op_select(cond, m_fp32,
                                                   op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1,
                                                          gradient_fp32))
        next_v = op_mul(beta2, v_fp32) + op_select(cond, v_fp32,
                                                   op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2,
                                                          op_square(gradient_fp32)))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update
        update_with_lr = op_mul(lr, update)
        zeros = F.fill(mstype.float32, op_shape(param_fp32), 0)
        next_param = param_fp32 - op_select(cond, zeros, op_reshape(update_with_lr, op_shape(param_fp32)))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
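

# ---------------------------------------------------------------------------
# Illustrative NumPy sketch (not part of the original module): what
# `_update_run_op` computes on the no-overflow path for a single dense
# parameter. The helper name and its arguments are assumptions made only
# for this sketch.
# ---------------------------------------------------------------------------
def _adamw_reference_step(param, m, v, grad, lr, beta1, beta2, eps, weight_decay, decay_flag=True):
    """NumPy emulation of the fused update above (no bias correction; eps is added to sqrt(v) in the denominator)."""
    next_m = beta1 * m + (1.0 - beta1) * grad
    next_v = beta2 * v + (1.0 - beta2) * np.square(grad)
    update = next_m / (eps + np.sqrt(next_v))
    if decay_flag:
        update = weight_decay * param + update
    next_param = param - lr * update
    return next_param, next_m, next_v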
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """Apply the sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        # Parameter-server path: push the optimizer inputs to the server and pull the updated parameter back.
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        assign_m = F.assign(m, op_mul(beta1, m))
        assign_v = F.assign(v, op_mul(beta2, v))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            # Stash next_m (scaled by 10) in m_temp, overwrite m to build div_value with another
            # scatter_add, then restore m from m_temp once param_update has been computed.
            m_temp = next_m * _scaler_ten
            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)

            m_recover = F.assign(m, m_temp / _scaler_ten)

            F.control_depend(m_temp, assign_m_nesterov)
            F.control_depend(assign_m_nesterov, div_value)
            F.control_depend(param_update, m_recover)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
        next_param = param - lr_t * param_update

        F.control_depend(assign_m, next_m)
        F.control_depend(assign_v, next_v)

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
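

# ---------------------------------------------------------------------------
# Illustrative NumPy sketch (not part of the original module) of the host-side
# branch above: a sparse gradient given as (indices, values) only touches the
# selected rows of m and v before the usual Adam step. The helper name and its
# signature are assumptions made only for this sketch.
# ---------------------------------------------------------------------------
def _sparse_adam_reference_step(param, m, v, indices, values, lr, beta1, beta2, eps, beta1_power, beta2_power):
    """NumPy emulation of the scatter_add based moment update (non-Nesterov path)."""
    m *= beta1
    v *= beta2
    np.add.at(m, indices, (1.0 - beta1) * values)             # ScatterAdd on the first moment
    np.add.at(v, indices, (1.0 - beta2) * np.square(values))  # ScatterAdd on the second moment
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    return param - lr_t * m / (np.sqrt(v) + eps), m, v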
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
                             beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
                             moment1, moment2, ps_parameter, cache_enable):
    """Apply adam optimizer to the weight parameter using Tensor."""
    success = True
    if ps_parameter and not cache_enable:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(param), op_shape(moment1), op_shape(moment2))), param))
    else:
        success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success
@_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Tensor")
def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
    """Apply AdamOffload optimizer to the weight parameter using Tensor."""
    success = True
    # The offload optimizer returns the parameter delta, which is accumulated into the parameter in place.
    delta_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
    success = F.depend(success, F.assign_add(param, delta_param))
    return success
def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the types and ranges of the inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
class AdamWeightDecayForBert(Optimizer):
    """
    Implements the Adam algorithm with weight decay, extended with an overflow flag for dynamic loss scaling.

    Note:
        When separating parameter groups, the weight decay in each group will be applied to the parameters if the
        weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
        to the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.

        To improve performance when using parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When `params` is a list of `Parameter` to be updated,
            each element must be of class `Parameter`. When `params` is a list of `dict`, the keys "params",
            "lr", "weight_decay" and "order_params" can be parsed.

            - params: Required. The value must be a list of `Parameter`.
            - lr: Optional. If "lr" is in the keys, the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.
            - weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.
            - order_params: Optional. If "order_params" is in the keys, its value must be the order of parameters,
              and that order will be followed by the optimizer. There should be no other keys in this `dict`, and
              every parameter listed in 'order_params' must also appear in one of the parameter groups.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning
            rate. When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used, and the
            i-th step takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule,
            a dynamic learning rate is also used, and the i-th learning rate is computed during training according
            to the formula of the LearningRateSchedule. When the learning_rate is a float or a scalar Tensor, a
            fixed learning rate is used. Other cases are not supported. The float learning rate must be equal to
            or greater than 0. If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, with the same shapes as `params`.
        - **overflow** (Tensor) - The overflow flag produced by dynamic loss scaling.

    Outputs:
        tuple[bool], all elements are True.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and a weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final order of parameters followed by the optimizer is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecayForBert, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()
        self.op_select = P.Select()
        self.op_cast = P.Cast()
        self.op_reshape = P.Reshape()
        self.op_shape = P.Shape()

    def construct(self, gradients, overflow):
        """AdamWeightDecayForBert"""
        lr = self.get_lr()
        # When the overflow flag is set, cond is True and beta1/beta2 are replaced by 1.0
        # before being passed to the update function.
        cond = self.op_cast(F.fill(mstype.int32, self.op_shape(self.beta1), 1) *
                            self.op_reshape(overflow, (())), mstype.bool_)
        beta1 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta1)
        beta2 = self.op_select(cond, self.op_cast(F.tuple_to_array((1.0,)), mstype.float32), self.beta2)
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, beta1, beta2, self.eps, lr, overflow),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
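

# ---------------------------------------------------------------------------
# Minimal construction sketch (illustrative, not part of the original module).
# It mirrors the docstring example above with this module's class; the toy
# nn.Dense network and the `overflow_flag` name are assumptions made only for
# this sketch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import mindspore.nn as nn

    net = nn.Dense(8, 2)
    decay_params = [p for p in net.trainable_params() if "bias" not in p.name]
    no_decay_params = [p for p in net.trainable_params() if "bias" in p.name]
    group_params = [{"params": decay_params, "weight_decay": 0.01},
                    {"params": no_decay_params, "lr": 0.01},
                    {"order_params": net.trainable_params()}]
    optimizer = AdamWeightDecayForBert(group_params, learning_rate=0.1, weight_decay=0.0)
    # At each training step the optimizer expects the gradient tuple plus the
    # overflow flag produced by dynamic loss scaling, i.e.
    #     optimizer(grads, overflow_flag)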