You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

adam.py 34 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """adam"""
  16. import numpy as np
  17. from mindspore.common import dtype as mstype
  18. from mindspore.common.initializer import initializer
  19. from mindspore.ops import operations as P
  20. from mindspore.ops import composite as C
  21. from mindspore.ops import functional as F
  22. from mindspore.common.parameter import Parameter
  23. from mindspore.common.tensor import Tensor
  24. from mindspore._checkparam import Validator as validator
  25. from mindspore._checkparam import Rel
  26. from .optimizer import Optimizer
  27. _adam_opt = C.MultitypeFuncGraph("adam_opt")
  28. _scaler_one = Tensor(1, mstype.int32)
  29. _scaler_ten = Tensor(10, mstype.float32)
  30. @_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
  31. "Tensor", "Bool", "Bool")
  32. def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
  33. """
  34. Update parameters.
  35. Args:
  36. beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
  37. beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
  38. eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
  39. lr (Tensor): Learning rate.
  40. weight_decay (Number): Weight decay. Should be equal to or greater than 0.
  41. param (Tensor): Parameters.
  42. m (Tensor): m value of parameters.
  43. v (Tensor): v value of parameters.
  44. gradient (Tensor): Gradient of parameters.
  45. decay_flag (bool): Applies weight decay or not.
  46. optim_filter (bool): Applies parameter update or not.
  47. Returns:
  48. Tensor, the new value of v after updating.
  49. """
  50. if optim_filter:
  51. op_mul = P.Mul()
  52. op_square = P.Square()
  53. op_sqrt = P.Sqrt()
  54. op_cast = P.Cast()
  55. op_reshape = P.Reshape()
  56. op_shape = P.Shape()
  57. param_fp32 = op_cast(param, mstype.float32)
  58. m_fp32 = op_cast(m, mstype.float32)
  59. v_fp32 = op_cast(v, mstype.float32)
  60. gradient_fp32 = op_cast(gradient, mstype.float32)
  61. next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
  62. - beta1, gradient_fp32)
  63. next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
  64. - beta2, op_square(gradient_fp32))
  65. update = next_m / (eps + op_sqrt(next_v))
  66. if decay_flag:
  67. update = op_mul(weight_decay, param_fp32) + update
  68. update_with_lr = op_mul(lr, update)
  69. next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
  70. next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
  71. next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
  72. next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))
  73. return op_cast(next_param, F.dtype(param))
  74. return gradient
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.

    This overload is selected when `gradient` is a RowTensor. It takes one of three paths:

    * `ps_parameter` is True and `cache_enable` is False: push the hyper-parameters and the sparse
      gradient (values and indices) with their shapes, then pull the result into `param`.
    * `target` is False (Adam sets it to False when its target is 'CPU'): run the fused `sparse_opt`
      primitive.
    * Otherwise: update `m`, `v` and `param` on the device using ScatterAdd so that only the rows
      named by `gradient.indices` receive the gradient contribution.

    `opt` is not used by this overload; it is part of the shared argument list.

    Returns:
        bool, True, with the update operations attached through `F.depend`.
    """
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter and not cache_enable:
        # Parameter-server path: shapes of every pushed tensor travel with the push.
        op_shape = P.Shape()
        shapes = (op_shape(param), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), param))
        return success

    if not target:
        # Host path: fused sparse Adam primitive (pinned to the CPU by the Adam optimizer).
        success = F.depend(success, sparse_opt(param, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        # Device path built from elementary ops.
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        # Decay every row of both moments: m = beta1 * m, v = beta2 * v.
        success = F.depend(success, F.assign(m, op_mul(beta1, m)))
        success = F.depend(success, F.assign(v, op_mul(beta2, v)))

        grad_indices = gradient.indices
        grad_value = gradient.values

        # Add the gradient terms only at the gradient's rows:
        # m[idx] += (1 - beta1) * g,  v[idx] += (1 - beta2) * g * g.
        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            # Nesterov look-ahead: temporarily set m to beta1 * next_m and scatter-add
            # (1 - beta1) * g again, so div_value = beta1 * next_m + (1 - beta1) * g at the
            # gradient rows. m is then restored to next_m from the saved copy.
            # NOTE(review): the *10 / 10 round trip through m_temp and the multiplication of
            # the indices by _scaler_one look like devices to create distinct graph nodes or
            # force ordering; these two F.assign calls are also not chained into `success`.
            # Confirm the intent before simplifying.
            m_temp = next_m * _scaler_ten
            F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)
            F.assign(m, m_temp / _scaler_ten)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        # Bias-corrected step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t).
        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)
        next_param = param - lr_t * param_update

        success = F.depend(success, F.assign(param, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
  125. @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
  126. "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool")
  127. def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target,
  128. beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param,
  129. moment1, moment2, ps_parameter, cache_enable):
  130. """Apply adam optimizer to the weight parameter using Tensor."""
  131. success = True
  132. if ps_parameter and not cache_enable:
  133. op_shape = P.Shape()
  134. success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
  135. (op_shape(param), op_shape(moment1), op_shape(moment2))), param))
  136. else:
  137. success = F.depend(success, opt(param, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
  138. eps, gradient))
  139. return success
  140. @_adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor",
  141. "Tensor", "Tensor")
  142. def _run_off_load_opt(opt, beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2):
  143. """Apply AdamOffload optimizer to the weight parameter using Tensor."""
  144. success = True
  145. delat_param = opt(moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, eps, gradient)
  146. success = F.depend(success, F.assign_add(param, delat_param))
  147. return success
  148. def _check_param_value(beta1, beta2, eps, prim_name):
  149. """Check the type of inputs."""
  150. validator.check_value_type("beta1", beta1, [float], prim_name)
  151. validator.check_value_type("beta2", beta2, [float], prim_name)
  152. validator.check_value_type("eps", eps, [float], prim_name)
  153. validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
  154. validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
  155. validator.check_positive_float(eps, "eps", prim_name)
  156. class Adam(Optimizer):
  157. r"""
  158. Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.
  159. The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
  160. The updating formulas are as follows,
  161. .. math::
  162. \begin{array}{ll} \\
  163. m = \beta_1 * m + (1 - \beta_1) * g \\
  164. v = \beta_2 * v + (1 - \beta_2) * g * g \\
  165. l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
  166. w = w - l * \frac{m}{\sqrt{v} + \epsilon}
  167. \end{array}
  168. :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
  169. :math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
  170. `beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
  171. `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
  172. :math:`\epsilon` represents `eps`.
  173. Note:
  174. When separating parameter groups, the weight decay in each group will be applied on the parameters if the
  175. weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
  176. on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
  177. To improve parameter groups performance, the customized order of parameters is supported.
  178. The sparse strategy is applied while the SparseGatherV2 operator is used for forward network.
  179. The sparse feature is under continuous development. If the sparse strategy wants to be executed on the host,
  180. set the target to the CPU.
  181. Args:
  182. params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
  183. the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
  184. "lr", "weight_decay" and "order_params" are the keys can be parsed.
  185. - params: Required. The value must be a list of `Parameter`.
  186. - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
  187. If not, the `learning_rate` in the API will be used.
  188. - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
  189. will be used. If not, the `weight_decay` in the API will be used.
  190. - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
  191. the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
  192. which in the 'order_params' must be in one of group parameters.
  193. - grad_centralization: Optional. If "grad_centralization" is in the keys, the set value will be used.
  194. If not, the `grad_centralization` is False by default. This parameter only works on the convolution layer.
  195. learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
  196. When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
  197. the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
  198. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
  199. according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
  200. dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
  201. equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
  202. Default: 1e-3.
  203. beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
  204. Default: 0.9.
  205. beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
  206. Default: 0.999.
  207. eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
  208. 1e-8.
  209. use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
  210. If true, updates of the var, m, and v tensors will be protected by a lock.
  211. If false, the result is unpredictable. Default: False.
  212. use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
  213. If true, update the gradients using NAG.
  214. If false, update the gradients without using NAG. Default: False.
  215. weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
  216. loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
  217. Inputs:
  218. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
  219. Outputs:
  220. Tensor[bool], the value is True.
  221. Supported Platforms:
  222. ``Ascend`` ``GPU``
  223. Examples:
  224. >>> net = Net()
  225. >>> #1) All parameters use the same learning rate and weight decay
  226. >>> optim = nn.Adam(params=net.trainable_params())
  227. >>>
  228. >>> #2) Use parameter groups and set different values
  229. >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
  230. >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
  231. >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
  232. ... {'params': no_conv_params, 'lr': 0.01},
  233. ... {'order_params': net.trainable_params()}]
  234. >>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
  235. >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01 and grad
  236. >>> # centralization of True.
  237. >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0 and grad
  238. >>> # centralization of False.
  239. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
  240. >>>
  241. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  242. >>> model = Model(net, loss_fn=loss, optimizer=optim)
  243. """
  244. def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
  245. use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
  246. super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
  247. _check_param_value(beta1, beta2, eps, self.cls_name)
  248. validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
  249. validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
  250. self.beta1 = Tensor(beta1, mstype.float32)
  251. self.beta2 = Tensor(beta2, mstype.float32)
  252. self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
  253. self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
  254. self.eps = Tensor(eps, mstype.float32)
  255. self.use_nesterov = use_nesterov
  256. self.use_locking = use_locking
  257. self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
  258. self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
  259. self._is_device = True
  260. self.hyper_map = C.HyperMap()
  261. self.opt = P.Adam(use_locking, use_nesterov)
  262. self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
  263. self.sparse_opt.add_prim_attr("primitive_target", "CPU")
  264. self._ps_pull = P.Pull()
  265. self._ps_push = P.Push("Adam", [0, 1, 2])
  266. self._ps_push.add_prim_attr("use_nesterov", use_nesterov)
  267. def construct(self, gradients):
  268. params = self.parameters
  269. moment1 = self.moment1
  270. moment2 = self.moment2
  271. gradients = self.decay_weight(gradients)
  272. gradients = self.scale_grad(gradients)
  273. gradients = self._grad_sparse_indices_deduplicate(gradients)
  274. gradients = self.gradients_centralization(gradients)
  275. lr = self.get_lr()
  276. beta1_power = self.beta1_power * self.beta1
  277. self.beta1_power = beta1_power
  278. beta2_power = self.beta2_power * self.beta2
  279. self.beta2_power = beta2_power
  280. if self.is_group_lr:
  281. success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
  282. self.use_locking, self.use_nesterov, self._is_device,
  283. beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
  284. lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
  285. else:
  286. success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
  287. self.use_locking, self.use_nesterov, self._is_device,
  288. beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
  289. gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable)
  290. return success
  291. @Optimizer.target.setter
  292. def target(self, value):
  293. """If the input value is set to "CPU", the parameters will be updated on the host using the Fused
  294. optimizer operation."""
  295. if not isinstance(value, str):
  296. raise TypeError("The value must be str type, but got value type is {}".format(type(value)))
  297. if value not in ('CPU', 'Ascend', 'GPU'):
  298. raise ValueError("The value must be 'CPU', 'Ascend' or 'GPU', but got value {}".format(value))
  299. if self._target == "CPU" and value in('Ascend', 'GPU'):
  300. raise ValueError("In the CPU environment, target cannot be set to 'GPU' and 'Ascend'.")
  301. if self._target == "Ascend" and value == 'GPU':
  302. raise ValueError("In the Ascend environment, target cannot be set to 'GPU'.")
  303. self._is_device = (value != 'CPU')
  304. self._target = value
  305. class AdamWeightDecay(Optimizer):
  306. """
  307. Implements the Adam algorithm to fix the weight decay.
  308. Note:
  309. When separating parameter groups, the weight decay in each group will be applied on the parameters if the
  310. weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
  311. on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
  312. To improve parameter groups performance, the customized order of parameters can be supported.
  313. Args:
  314. params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
  315. the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
  316. "lr", "weight_decay" and "order_params" are the keys can be parsed.
  317. - params: Required. The value must be a list of `Parameter`.
  318. - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
  319. If not, the `learning_rate` in the API will be used.
  320. - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
  321. will be used. If not, the `weight_decay` in the API will be used.
  322. - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
  323. the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
  324. which in the 'order_params' must be in one of group parameters.
  325. learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
  326. When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
  327. the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
  328. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
  329. according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
  330. dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
  331. equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
  332. Default: 1e-3.
  333. beta1 (float): The exponential decay rate for the 1st moment estimations. Default: 0.9.
  334. Should be in range (0.0, 1.0).
  335. beta2 (float): The exponential decay rate for the 2nd moment estimations. Default: 0.999.
  336. Should be in range (0.0, 1.0).
  337. eps (float): Term added to the denominator to improve numerical stability. Default: 1e-6.
  338. Should be greater than 0.
  339. weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
  340. Inputs:
  341. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
  342. Outputs:
  343. tuple[bool], all elements are True.
  344. Supported Platforms:
  345. ``Ascend`` ``GPU``
  346. Examples:
  347. >>> net = Net()
  348. >>> #1) All parameters use the same learning rate and weight decay
  349. >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
  350. >>>
  351. >>> #2) Use parameter groups and set different values
  352. >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
  353. >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
  354. >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
  355. ... {'params': no_conv_params, 'lr': 0.01},
  356. ... {'order_params': net.trainable_params()}]
  357. >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
  358. >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
  359. >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
  360. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
  361. >>>
  362. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  363. >>> model = Model(net, loss_fn=loss, optimizer=optim)
  364. """
  365. def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
  366. super(AdamWeightDecay, self).__init__(learning_rate, params, weight_decay)
  367. _check_param_value(beta1, beta2, eps, self.cls_name)
  368. self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
  369. self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
  370. self.eps = Tensor(np.array([eps]).astype(np.float32))
  371. self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
  372. self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
  373. self.hyper_map = C.HyperMap()
  374. def construct(self, gradients):
  375. lr = self.get_lr()
  376. if self.is_group:
  377. if self.is_group_lr:
  378. optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
  379. lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
  380. gradients, self.decay_flags, self.optim_filter)
  381. else:
  382. optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
  383. self.weight_decay, self.parameters, self.moments1, self.moments2,
  384. gradients, self.decay_flags, self.optim_filter)
  385. else:
  386. optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr, self.weight_decay),
  387. self.parameters, self.moments1, self.moments2,
  388. gradients, self.decay_flags, self.optim_filter)
  389. if self.use_parallel:
  390. self.broadcast_params(optim_result)
  391. return optim_result
  392. class AdamOffload(Optimizer):
  393. r"""
  394. This optimizer will offload Adam optimizer to host CPU and keep parameters being updated on the device,
  395. to minimize the memory cost. Although that would bring about an increase of performance overhead,
  396. the optimizer could be used to run a larger model.
  397. The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
  398. The updating formulas are as follows,
  399. .. math::
  400. \begin{array}{ll} \\
  401. m = \beta_1 * m + (1 - \beta_1) * g \\
  402. v = \beta_2 * v + (1 - \beta_2) * g * g \\
  403. l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
  404. w = w - l * \frac{m}{\sqrt{v} + \epsilon}
  405. \end{array}
  406. :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
  407. :math:`g` represents `gradients`, :math:`l` represents scaling factor `lr`, :math:`\beta_1, \beta_2` represent
  408. `beta1` and `beta2`, :math:`t` represents updating step while :math:`beta_1^t` and :math:`beta_2^t` represent
  409. `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`,
  410. :math:`\epsilon` represents `eps`.
  411. Note:
  412. This optimizer only supports `GRAPH_MODE` currently.
  413. When separating parameter groups, the weight decay in each group will be applied on the parameters if the
  414. weight decay is positive. When not separating parameter groups, the `weight_decay` in the API will be applied
  415. on the parameters without 'beta' or 'gamma' in their names if `weight_decay` is positive.
  416. To improve parameter groups performance, the customized order of parameters is supported.
  417. Args:
  418. params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
  419. the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
  420. "lr", "weight_decay" and "order_params" are the keys can be parsed.
  421. - params: Required. The value must be a list of `Parameter`.
  422. - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
  423. If not, the `learning_rate` in the API will be used.
  424. - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
  425. will be used. If not, the `weight_decay` in the API will be used.
  426. - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
  427. the order will be followed in the optimizer. There are no other keys in the `dict` and the parameters
  428. which in the 'order_params' must be in one of group parameters.
  429. learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
  430. When the learning_rate is an Iterable or a Tensor in a 1D dimension, use the dynamic learning rate, then
  431. the i-th step will take the i-th value as the learning rate. When the learning_rate is LearningRateSchedule,
  432. use dynamic learning rate, the i-th learning rate will be calculated during the process of training
  433. according to the formula of LearningRateSchedule. When the learning_rate is a float or a Tensor in a zero
  434. dimension, use fixed learning rate. Other cases are not supported. The float learning rate must be
  435. equal to or greater than 0. If the type of `learning_rate` is int, it will be converted to float.
  436. Default: 1e-3.
  437. beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
  438. Default: 0.9.
  439. beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
  440. Default: 0.999.
  441. eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0. Default:
  442. 1e-8.
  443. use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
  444. If true, updates of the var, m, and v tensors will be protected by a lock.
  445. If false, the result is unpredictable. Default: False.
  446. use_nesterov (bool): Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
  447. If true, update the gradients using NAG.
  448. If false, update the gradients without using NAG. Default: False.
  449. weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
  450. loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.
  451. Inputs:
  452. - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.
  453. Outputs:
  454. Tensor[bool], the value is True.
  455. Supported Platforms:
  456. ``Ascend`` ``GPU`` ``CPU``
  457. Examples:
  458. >>> net = Net()
  459. >>> #1) All parameters use the same learning rate and weight decay
  460. >>> optim = nn.AdamOffload(params=net.trainable_params())
  461. >>>
  462. >>> #2) Use parameter groups and set different values
  463. >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
  464. >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
  465. >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
  466. ... {'params': no_conv_params, 'lr': 0.01},
  467. ... {'order_params': net.trainable_params()}]
  468. >>> optim = nn.AdamOffload(group_params, learning_rate=0.1, weight_decay=0.0)
  469. >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01.
  470. >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0.
  471. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'.
  472. >>>
  473. >>> loss = nn.SoftmaxCrossEntropyWithLogits()
  474. >>> model = Model(net, loss_fn=loss, optimizer=optim)
  475. """
  476. def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
  477. use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
  478. super(AdamOffload, self).__init__(learning_rate, params, weight_decay, loss_scale)
  479. _check_param_value(beta1, beta2, eps, self.cls_name)
  480. validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
  481. validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)
  482. self.beta1 = Tensor(beta1, mstype.float32)
  483. self.beta2 = Tensor(beta2, mstype.float32)
  484. self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
  485. self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
  486. self.eps = Tensor(eps, mstype.float32)
  487. self.use_nesterov = use_nesterov
  488. self.use_locking = use_locking
  489. self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
  490. self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')
  491. self.hyper_map = C.HyperMap()
  492. self.opt = P.AdamNoUpdateParam(use_locking, use_nesterov)
  493. self.opt.add_prim_attr("primitive_target", "CPU")
  494. def construct(self, gradients):
  495. params = self.parameters
  496. moment1 = self.moment1
  497. moment2 = self.moment2
  498. gradients = self.decay_weight(gradients)
  499. gradients = self.scale_grad(gradients)
  500. lr = self.get_lr()
  501. beta1_power = self.beta1_power * self.beta1
  502. self.beta1_power = beta1_power
  503. beta2_power = self.beta2_power * self.beta2
  504. self.beta2_power = beta2_power
  505. if self.is_group_lr:
  506. success = self.map_(F.partial(_adam_opt, self.opt,
  507. beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
  508. lr, gradients, params, moment1, moment2)
  509. else:
  510. success = self.map_(F.partial(_adam_opt, self.opt,
  511. beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
  512. gradients, params, moment1, moment2)
  513. return success