# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adam"""
import numpy as np

from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from .optimizer import Optimizer

_adam_opt = C.MultitypeFuncGraph("adam_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
@_adam_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor",
                    "Tensor", "Bool", "Bool")
def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Applies weight decay or not.
        optim_filter (bool): Applies parameter update or not.

    Returns:
        Tensor, the new value of the parameters after updating.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()

        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                                - beta1, gradient_fp32)

        next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
                                                - beta2, op_square(gradient_fp32))

        update = next_m / (eps + op_sqrt(next_v))
        if decay_flag:
            update = op_mul(weight_decay, param_fp32) + update

        update_with_lr = op_mul(lr, update)
        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

        next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
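

# Editorial addition, not part of the original module: a minimal NumPy sketch of the update
# performed by `_update_run_op` above (the step used by `AdamWeightDecay`, without bias
# correction), provided only to make the arithmetic easy to check. The function name and
# signature are hypothetical and are not used anywhere in this module.
def _adamw_reference_update(param, m, v, gradient, lr, beta1, beta2, eps, weight_decay, decay_flag=True):
    """Hypothetical NumPy restatement of `_update_run_op`; for illustration only."""
    # first and second moment estimates
    next_m = beta1 * m + (1.0 - beta1) * gradient
    next_v = beta2 * v + (1.0 - beta2) * gradient * gradient
    # Adam direction, optionally combined with the decoupled weight-decay term
    update = next_m / (eps + np.sqrt(next_v))
    if decay_flag:
        update = weight_decay * param + update
    next_param = param - lr * update
    return next_param, next_m, next_v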
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                         beta2_power, beta1, beta2, eps, lr, gradient, params, m, v, ps_parameter):
    """Apply sparse adam optimizer to the weight parameter when the gradient is sparse."""
    success = True
    indices = gradient.indices
    values = gradient.values
    if ps_parameter:
        op_shape = P.Shape()
        shapes = (op_shape(params), op_shape(m), op_shape(v),
                  op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1),
                  op_shape(beta2), op_shape(eps), op_shape(values), op_shape(indices))
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices), shapes), params))
        return success

    if not target:
        success = F.depend(success, sparse_opt(params, m, v, beta1_power, beta2_power, lr, beta1, beta2,
                                               eps, values, indices))
    else:
        op_mul = P.Mul()
        op_square = P.Square()
        op_sqrt = P.Sqrt()
        scatter_add = P.ScatterAdd(use_locking)

        assign_m = F.assign(m, op_mul(beta1, m))
        assign_v = F.assign(v, op_mul(beta2, v))

        grad_indices = gradient.indices
        grad_value = gradient.values

        next_m = scatter_add(m,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))

        next_v = scatter_add(v,
                             grad_indices,
                             op_mul(F.tuple_to_array((1.0,)) - beta2, op_square(grad_value)))

        if use_nesterov:
            m_temp = next_m * _scaler_ten
            assign_m_nesterov = F.assign(m, op_mul(beta1, next_m))
            div_value = scatter_add(m,
                                    op_mul(grad_indices, _scaler_one),
                                    op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value))
            param_update = div_value / (op_sqrt(next_v) + eps)

            m_recover = F.assign(m, m_temp / _scaler_ten)

            F.control_depend(m_temp, assign_m_nesterov)
            F.control_depend(assign_m_nesterov, div_value)
            F.control_depend(param_update, m_recover)
        else:
            param_update = next_m / (op_sqrt(next_v) + eps)

        lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power)

        next_param = params - lr_t * param_update

        F.control_depend(assign_m, next_m)
        F.control_depend(assign_v, next_v)

        success = F.depend(success, F.assign(params, next_param))
        success = F.depend(success, F.assign(m, next_m))
        success = F.depend(success, F.assign(v, next_v))

    return success
@_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor",
                    "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool")
def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power,
                             beta2_power, beta1, beta2, eps, lr, gradient, params, moment1, moment2, ps_parameter):
    """Apply adam optimizer to the weight parameter using Tensor."""
    success = True
    if ps_parameter:
        op_shape = P.Shape()
        success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient),
                                              (op_shape(params), op_shape(moment1), op_shape(moment2))), params))
    else:
        success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2,
                                        eps, gradient))
    return success
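

# Editorial addition, not part of the original module: a hypothetical NumPy sketch of the
# dense step delegated to the fused `P.Adam` primitive above, following the formulas in the
# `Adam` class docstring below (including the bias-corrected step size). Illustration only;
# the function name is invented and is not used anywhere in this module.
def _adam_reference_update(param, m, v, gradient, lr, beta1, beta2, beta1_power, beta2_power, eps):
    """Hypothetical NumPy restatement of the standard Adam step; for illustration only."""
    next_m = beta1 * m + (1.0 - beta1) * gradient
    next_v = beta2 * v + (1.0 - beta2) * gradient * gradient
    # bias-corrected step size: l = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    next_param = param - lr_t * next_m / (np.sqrt(next_v) + eps)
    return next_param, next_m, next_v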
def _check_param_value(beta1, beta2, eps, prim_name):
    """Check the type of inputs."""
    validator.check_value_type("beta1", beta1, [float], prim_name)
    validator.check_value_type("beta2", beta2, [float], prim_name)
    validator.check_value_type("eps", eps, [float], prim_name)
    validator.check_float_range(beta1, 0.0, 1.0, Rel.INC_NEITHER, "beta1", prim_name)
    validator.check_float_range(beta2, 0.0, 1.0, Rel.INC_NEITHER, "beta2", prim_name)
    validator.check_positive_float(eps, "eps", prim_name)
class Adam(Optimizer):
    r"""
    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.

    The Adam algorithm is proposed in `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows:

    .. math::
        \begin{array}{ll} \\
            m = \beta_1 * m + (1 - \beta_1) * g \\
            v = \beta_2 * v + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w = w - l * \frac{m}{\sqrt{v} + \epsilon}
        \end{array}

    :math:`m` represents the 1st moment vector `moment1`, :math:`v` represents the 2nd moment vector `moment2`,
    :math:`g` represents `gradients`, :math:`l` represents the scaling factor, :math:`\beta_1, \beta_2` represent
    `beta1` and `beta2`, :math:`t` represents the updating step while :math:`\beta_1^t` and :math:`\beta_2^t` represent
    `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`, and
    :math:`\epsilon` represents `eps`.

    Note:
        When parameters are separated into groups, the weight decay in each group is applied to the parameters of
        that group if the weight decay is positive. When parameters are not separated into groups, the
        `weight_decay` in the API is applied to the parameters without 'beta' or 'gamma' in their names if
        `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

        The sparse strategy is applied when the SparseGatherV2 operator is used in the forward network.
        The sparse feature is under continuous development. To execute the sparse strategy on the host,
        set the target to the CPU.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. No other keys may appear in this `dict`, and the
              parameters in 'order_params' must be included in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used, and the i-th step
            takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a dynamic
            learning rate is also used, and the i-th learning rate is computed during training according to the
            formula of the LearningRateSchedule. When the learning_rate is a float or a 0-D Tensor, a fixed learning
            rate is used. Other cases are not supported. The float learning rate must be equal to or greater than 0.
            If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
            Default: 1e-8.
        use_locking (bool): Whether to enable a lock to protect variable tensors from being updated.
            If true, updates of the var, m, and v tensors will be protected by a lock.
            If false, the result is unpredictable. Default: False.
        use_nesterov (bool): Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients.
            If true, update the gradients using NAG.
            If false, update the gradients without using NAG. Default: False.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.
        loss_scale (float): A floating point value for the loss scale. Should be greater than 0. Default: 1.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        Tensor[bool], the value is True.

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.Adam(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and a weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                 use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
        super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        validator.check_value_type("use_locking", use_locking, [bool], self.cls_name)
        validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name)

        self.beta1 = Tensor(beta1, mstype.float32)
        self.beta2 = Tensor(beta2, mstype.float32)
        self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
        self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
        self.eps = Tensor(eps, mstype.float32)
        self.use_nesterov = use_nesterov
        self.use_locking = use_locking

        self.moment1 = self.parameters.clone(prefix="moment1", init='zeros')
        self.moment2 = self.parameters.clone(prefix="moment2", init='zeros')

        self._is_device = True
        self.hyper_map = C.HyperMap()
        self.opt = P.Adam(use_locking, use_nesterov)
        self.sparse_opt = P.FusedSparseAdam(use_locking, use_nesterov)
        self.sparse_opt.add_prim_attr("primitive_target", "CPU")
        self._ps_pull = P.Pull()
        self._ps_push = P.Push("Adam", [0, 1, 2])
        self._ps_push.add_prim_attr("use_nesterov", use_nesterov)

    def construct(self, gradients):
        params = self.parameters
        moment1 = self.moment1
        moment2 = self.moment2
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        lr = self.get_lr()

        beta1_power = self.beta1_power * self.beta1
        self.beta1_power = beta1_power
        beta2_power = self.beta2_power * self.beta2
        self.beta2_power = beta2_power
        if self.is_group_lr:
            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          self.use_locking, self.use_nesterov, self._is_device,
                                          beta1_power, beta2_power, self.beta1, self.beta2, self.eps),
                                lr, gradients, params, moment1, moment2, self.ps_parameters)
        else:
            success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
                                          self.use_locking, self.use_nesterov, self._is_device,
                                          beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr),
                                gradients, params, moment1, moment2, self.ps_parameters)
        return success

    @Optimizer.target.setter
    def target(self, value):
        """
        If the input value is set to "CPU", the parameters will be updated on the host using the Fused
        optimizer operation.
        """
        if value not in ('CPU', 'Ascend'):
            raise ValueError("The value must be 'CPU' or 'Ascend', but got value {}".format(value))

        self._is_device = (value != 'CPU')
        self._target = value
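

# Editorial addition, not part of the original module: a minimal usage sketch of the `target`
# setter defined above, per the class docstring note on running the sparse strategy on the
# host. `make_net` is a hypothetical user-supplied network factory; the helper itself is not
# used anywhere in this module.
def _example_adam_on_host(make_net):
    """Hypothetical helper, for illustration only: build Adam and route sparse updates to the CPU."""
    net = make_net()
    optim = Adam(net.trainable_params(), learning_rate=0.1, weight_decay=0.0)
    optim.target = 'CPU'  # per the setter above: parameters are updated on the host
    return optim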
class AdamWeightDecay(Optimizer):
    """
    Implements the Adam algorithm with weight decay (AdamWeightDecay).

    Note:
        When parameters are separated into groups, the weight decay in each group is applied to the parameters of
        that group if the weight decay is positive. When parameters are not separated into groups, the
        `weight_decay` in the API is applied to the parameters without 'beta' or 'gamma' in their names if
        `weight_decay` is positive.

        To improve the performance of parameter groups, a customized order of parameters is supported.

    Args:
        params (Union[list[Parameter], list[dict]]): When the `params` is a list of `Parameter` which will be updated,
            the element in `params` must be class `Parameter`. When the `params` is a list of `dict`, the "params",
            "lr", "weight_decay" and "order_params" are the keys that can be parsed.

            - params: Required. The value must be a list of `Parameter`.

            - lr: Optional. If "lr" is in the keys, the value of the corresponding learning rate will be used.
              If not, the `learning_rate` in the API will be used.

            - weight_decay: Optional. If "weight_decay" is in the keys, the value of the corresponding weight decay
              will be used. If not, the `weight_decay` in the API will be used.

            - order_params: Optional. If "order_params" is in the keys, the value must be the order of parameters and
              this order will be followed in the optimizer. No other keys may appear in this `dict`, and the
              parameters in 'order_params' must be included in one of the group parameters.

        learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): A value or a graph for the learning rate.
            When the learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used, and the i-th step
            takes the i-th value as the learning rate. When the learning_rate is a LearningRateSchedule, a dynamic
            learning rate is also used, and the i-th learning rate is computed during training according to the
            formula of the LearningRateSchedule. When the learning_rate is a float or a 0-D Tensor, a fixed learning
            rate is used. Other cases are not supported. The float learning rate must be equal to or greater than 0.
            If the type of `learning_rate` is int, it will be converted to float. Default: 1e-3.
        beta1 (float): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
            Default: 0.9.
        beta2 (float): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
            Default: 0.999.
        eps (float): Term added to the denominator to improve numerical stability. Should be greater than 0.
            Default: 1e-6.
        weight_decay (float): Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

    Inputs:
        - **gradients** (tuple[Tensor]) - The gradients of `params`, the shape is the same as `params`.

    Outputs:
        tuple[bool], all elements are True.

    Examples:
        >>> net = Net()
        >>> #1) All parameters use the same learning rate and weight decay
        >>> optim = nn.AdamWeightDecay(params=net.trainable_params())
        >>>
        >>> #2) Use parameter groups and set different values
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
        ...                 {'params': no_conv_params, 'lr': 0.01},
        ...                 {'order_params': net.trainable_params()}]
        >>> optim = nn.AdamWeightDecay(group_params, learning_rate=0.1, weight_decay=0.0)
        >>> # The conv_params's parameters will use the default learning rate of 0.1 and a weight decay of 0.01.
        >>> # The no_conv_params's parameters will use a learning rate of 0.01 and the default weight decay of 0.0.
        >>> # The final order of parameters that the optimizer follows is the value of 'order_params'.
        >>>
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
    """
    def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0):
        super(AdamWeightDecay, self).__init__(learning_rate, params, weight_decay)
        _check_param_value(beta1, beta2, eps, self.cls_name)
        self.beta1 = Tensor(np.array([beta1]).astype(np.float32))
        self.beta2 = Tensor(np.array([beta2]).astype(np.float32))
        self.eps = Tensor(np.array([eps]).astype(np.float32))
        self.moments1 = self.parameters.clone(prefix="adam_m", init='zeros')
        self.moments2 = self.parameters.clone(prefix="adam_v", init='zeros')
        self.hyper_map = C.HyperMap()

    def construct(self, gradients):
        lr = self.get_lr()
        if self.is_group:
            if self.is_group_lr:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps),
                                              lr, self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
            else:
                optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr),
                                              self.weight_decay, self.parameters, self.moments1, self.moments2,
                                              gradients, self.decay_flags, self.optim_filter)
        else:
            optim_result = self.hyper_map(F.partial(_adam_opt, self.beta1, self.beta2, self.eps, lr,
                                                    self.weight_decay),
                                          self.parameters, self.moments1, self.moments2,
                                          gradients, self.decay_flags, self.optim_filter)
        if self.use_parallel:
            self.broadcast_params(optim_result)
        return optim_result
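

if __name__ == "__main__":
    # Editorial addition, not part of the original module: a small, self-contained illustration
    # of the bias-corrected step size l = lr * sqrt(1 - beta2^t) / (1 - beta1^t) from the `Adam`
    # class docstring, showing how the effective step size settles over the first few steps.
    lr, beta1, beta2 = 1e-3, 0.9, 0.999
    for t in range(1, 6):
        lr_t = lr * np.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)
        print("step", t, "effective step size", lr_t)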