You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

adam.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """adam"""
  16. import numpy as np
  17. from mindspore.common import dtype as mstype
  18. from mindspore.common.initializer import initializer
  19. from mindspore.ops import operations as P
  20. from mindspore.ops import composite as C
  21. from mindspore.ops import functional as F
  22. from mindspore.common.parameter import Parameter
  23. from mindspore.common.tensor import Tensor
  24. from mindspore._checkparam import Validator as validator
  25. from mindspore._checkparam import Rel
  26. from .optimizer import Optimizer
# Names of the supported learning-rate update schedules (kept for callers
# elsewhere in the package; not referenced inside this module).
_learning_rate_update_func = ['linear', 'cos', 'sin']

# Multitype graph that dispatches the per-parameter Adam update.  Two
# implementations are registered below: one driven by the fused P.Adam
# primitive and one hand-written update with decoupled weight decay.
adam_opt = C.MultitypeFuncGraph("adam_opt")
@adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, gradient, decay_flag):
    """
    Update one parameter with the Adam rule plus decoupled weight decay.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        weight_decay_tensor (Tensor): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameter to update (assigned in place).
        m (Tensor): 1st moment estimate of the parameter (assigned in place).
        v (Tensor): 2nd moment estimate of the parameter (assigned in place).
        gradient (Tensor): Gradient of the parameter.
        decay_flag (bool): Whether weight decay is applied to this parameter.

    Returns:
        Tensor, the new value of v after updating.
    """
    op_mul = P.Mul()
    op_square = P.Square()
    op_sqrt = P.Sqrt()
    op_cast = P.Cast()
    op_reshape = P.Reshape()
    op_shape = P.Shape()
    # Do all arithmetic in float32 regardless of the stored dtype.
    param = op_cast(param, mstype.float32)
    m = op_cast(m, mstype.float32)
    v = op_cast(v, mstype.float32)
    gradient = op_cast(gradient, mstype.float32)
    # m <- beta1*m + (1-beta1)*g ; v <- beta2*v + (1-beta2)*g^2
    next_m = op_mul(beta1, m) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient)
    next_v = op_mul(beta2, v) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient))
    update = next_m / (op_sqrt(next_v) + eps)
    if decay_flag:
        # Decoupled weight decay: added to the update, not to the gradient.
        update = update + op_mul(weight_decay_tensor, param)
    update_with_lr = op_mul(lr, update)
    next_param = param - op_reshape(update_with_lr, op_shape(param))
    # F.depend chains the three in-place assignments into the data flow so the
    # graph executes them before next_v is consumed; do not reorder.
    next_v = F.depend(next_v, F.assign(param, next_param))
    next_v = F.depend(next_v, F.assign(m, next_m))
    next_v = F.depend(next_v, F.assign(v, next_v))
    return next_v
  67. def _check_param_value(beta1, beta2, eps, weight_decay, prim_name):
  68. """Check the type of inputs."""
  69. validator.check_value_type("beta1", beta1, [float], prim_name)
  70. validator.check_value_type("beta2", beta2, [float], prim_name)
  71. validator.check_value_type("eps", eps, [float], prim_name)
  72. validator.check_value_type("weight_dacay", weight_decay, [float], prim_name)
  73. validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
  74. validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name)
  75. validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name)
  76. validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name)
  77. def _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, prim_name):
  78. """Check the type of inputs."""
  79. validator.check_float_positive('learning_rate', learning_rate, prim_name)
  80. validator.check_float_legal_value('learning_rate', learning_rate, prim_name)
  81. validator.check_float_positive('end_learning_rate', end_learning_rate, prim_name)
  82. validator.check_float_legal_value('end_learning_rate', end_learning_rate, prim_name)
  83. validator.check_float_positive('power', power, prim_name)
  84. validator.check_float_legal_value('power', power, prim_name)
  85. validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
  86. @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
  87. "Tensor")
  88. def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1,
  89. moment2):
  90. """Apply adam optimizer to the weight parameter using Tensor."""
  91. success = True
  92. success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
  93. eps, gradient))
  94. return success
class Adam(Optimizer):
    r"""
    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization
    <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows,

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
    `beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
    `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
    :math:`\epsilon` represents `eps`.

    Args:
        params (list[Parameter]): A list of parameters, which will be updated. The elements in `params`
            should be class mindspore.Parameter.
        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
            Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step
            will take the i-th value as the learning rate. When the learning_rate is float or learning_rate is
            a Tensor but the dims of the Tensor is 0, use fixed learning rate. Other cases are not supported.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimates. Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
        use_locking (bool): Whether to enable a lock to protect updating variable tensors.
            If True, updating of the var, m, and v tensors will be protected by a lock.
            If False, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If True, updates the gradients using NAG.
            If False, updates the gradients without using NAG. Default: False.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. Should be equal to or greater than 1.
            Default: 1.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'beta' not in x.name and 'gamma' not in x.name.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Adam(params=net.trainable_params())
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """

    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter)
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
        validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name)
        validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        # Running products beta1^t and beta2^t used by the fused op for bias
        # correction; updated on every construct() call.
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = eps

        # Per-parameter moment accumulators, zero-initialized.
        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')

        self.hyper_map = C.HyperMap()
        # Fused Adam primitive; the actual update math runs inside this op.
        self.opt = P.Adam(use_locking, use_nesterov)
        self.pow = P.Pow()
        self.sqrt = P.Sqrt()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.realdiv = P.RealDiv()

    def construct(self, gradients):
        """Apply one fused-Adam step to every parameter; returns the success flag."""
        params = self.parameters
        moment1 = self.moment1
        moment2 = self.moment2
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()

        # Advance the bias-correction powers (beta^t) for this step.
        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power
        # Map the registered fused-op branch of adam_opt over all parameters.
        success = self.hyper_map(F.partial(adam_opt, self.opt, lr, beta1_power, beta2_power, self.beta1,
                                           self.beta2, self.eps),
                                 gradients, params, moment1, moment2)
        return success
class AdamWeightDecay(Optimizer):
    """
    Implements the Adam algorithm with the weight-decay fix (decoupled weight decay).

    Unlike :class:`Adam`, the decay term is added directly to the parameter update
    inside `_update_run_op` instead of being folded into the gradients.

    Args:
        params (list[Parameter]): A list of parameters, which will be updated. The elements in `params`
            should be class mindspore.Parameter.
        learning_rate (Union[float, Tensor, Iterable]): A value for the learning rate. When the learning_rate is
            Iterable or a Tensor and the dims of the Tensor is 1, use dynamic learning rate, then the i-th step
            will take the i-th value as the learning rate. When the learning_rate is float or learning_rate is
            a Tensor but the dims of the Tensor is 0, use fixed learning rate. Other cases are not supported.
            Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'beta' not in x.name and 'gamma' not in x.name.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[Parameter], the updated velocity value, the shape is the same as `params`.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """

    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecay, self).__init__(learning_rate, params)
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        # Hyper-parameters stored as 1-element float32 tensors for graph-mode math.
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        # Per-parameter moment accumulators, zero-initialized.
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        # Per-parameter decay flags, evaluated once at construction time.
        self.decay_flag = tuple(decay_filter(x) for x in self.params)

        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        """Run one AdamWeightDecay step over all parameters; returns updated v values."""
        lr = self.get_lr()
        # Map the hand-written branch of adam_opt (_update_run_op) over all parameters.
        updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay_tensor),
                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        return updated_velocity
class AdamWeightDecayDynamicLR(Optimizer):
    """
    Adam Weight Decay with a polynomially decaying dynamic learning rate.

    The learning rate decays from `learning_rate` to `end_learning_rate` over
    `decay_steps` steps following
    ``lr = (learning_rate - end_learning_rate) * (1 - step/decay_steps) ** power + end_learning_rate``,
    after which it stays at `end_learning_rate`.

    Args:
        params (list[Parameter]): A list of parameters, which will be updated. The elements in `params`
            should be class mindspore.Parameter.
        decay_steps (int): The number of steps over which the learning rate decays.
        learning_rate (float): Initial learning rate. Default: 0.001.
        end_learning_rate (float): Final learning rate after decay. Default: 0.0001.
        power (float): Exponent of the polynomial decay. Default: 10.0.
        beta1 (float): The exponential decay rate for the 1st moment estimates. Default: 0.9.
            Should be in range (0.0, 1.0).
        beta2 (float): The exponential decay rate for the 2nd moment estimates. Default: 0.999.
            Should be in range (0.0, 1.0).
        eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
            Should be greater than 0.
        weight_decay (float): Weight decay (L2 penalty). Default: 0.0.
        decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default:
            lambda x: 'beta' not in x.name and 'gamma' not in x.name.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[Parameter], the updated velocity value, the shape is the same as `params`.

    Examples:
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.AdamWeightDecayDynamicLR(params=net.trainable_params(), decay_steps=10)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """

    def __init__(self,
                 params,
                 decay_steps,
                 learning_rate=0.001,
                 end_learning_rate=0.0001,
                 power=10.0,
                 beta1=0.9,
                 beta2=0.999,
                 eps=1e-6,
                 weight_decay=0.0,
                 decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
        super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params)
        _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name)
        _check_learning_rate_value(learning_rate, end_learning_rate, decay_steps, power, self.cls_name)
        # TODO: turn these into scalars once scalar/tensor mixed operations are supported.
        self.global_step = Parameter(initializer(0, [1]), name="global_step")
        self.decay_steps = Tensor(np.array([decay_steps]).astype(np.float32))
        self.end_learning_rate = Tensor(np.array([end_learning_rate]).astype(np.float32))
        # Pre-computed (learning_rate - end_learning_rate) used by the decay formula.
        self.diff_learning_rate = Tensor(np.array([learning_rate - end_learning_rate]).astype(np.float32))
        self.power = power
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.weight_decay_tensor = Tensor(np.array([weight_decay]).astype(np.float32))

        self.params = self.parameters
        self.moments1 = self.params.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.params.clone(prefix="adam_v", init='zeros')
        # Per-parameter decay flags, evaluated once at construction time.
        self.decay_flag = tuple(decay_filter(x) for x in self.params)

        self.hyper_map = C.HyperMap()
        self.min = P.Minimum()
        self.pow = P.Pow()
        self.one = Tensor(np.array([1.0]).astype(np.float32))

    def construct(self, gradients):
        """Compute this step's decayed lr, apply the update, then advance global_step."""
        # Clamp so the learning rate stays at end_learning_rate past decay_steps.
        step = self.min(self.global_step, self.decay_steps)
        p = step / self.decay_steps
        lr = self.diff_learning_rate * self.pow(self.one - p, self.power) + self.end_learning_rate
        updated_velocity = self.hyper_map(F.partial(adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay_tensor),
                                          self.params, self.moments1, self.moments2, gradients, self.decay_flag)
        # control_depend forces the lr computation to execute before the step
        # counter is incremented, so lr is derived from the pre-update step.
        added_global_step = self.global_step + self.one
        F.control_depend(lr, added_global_step)
        self.global_step = added_global_step
        return updated_velocity