mindspore.nn.Adam
==================

.. py:class:: mindspore.nn.Adam(*args, **kwargs)

    Updates gradients by the Adaptive Moment Estimation (Adam) algorithm.

    See the paper `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.

    The updating formulas are as follows:

    .. math::
        \begin{array}{ll} \\
            m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
            v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
            l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
            w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
        \end{array}

    :math:`m` represents the first moment vector `moment1`, :math:`v` represents the second moment vector `moment2`, :math:`g` represents `gradients`, :math:`l` represents the scaling factor, :math:`\beta_1, \beta_2` represent `beta1` and `beta2`, :math:`t` represents the current step, :math:`\beta_1^t` and :math:`\beta_2^t` represent `beta1_power` and `beta2_power`, :math:`\alpha` represents `learning_rate`, :math:`w` represents `params`, and :math:`\epsilon` represents `eps`.
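
    The update can be checked against a plain implementation. Below is a minimal NumPy sketch of a single Adam step following the formulas above; it only illustrates the math and is not MindSpore's fused implementation (the helper name `adam_step` is hypothetical).

    .. code-block:: python

        import numpy as np

        def adam_step(w, m, v, g, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
            """One Adam update following the formulas above (illustrative only)."""
            m = beta1 * m + (1 - beta1) * g                        # m_{t+1}
            v = beta2 * v + (1 - beta2) * g * g                    # v_{t+1}
            l = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)    # bias-corrected scale
            w = w - l * m / (np.sqrt(v) + eps)                     # w_{t+1}
            return w, m, v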

    .. note::
        If the forward network uses operators such as SparseGatherV2, the optimizer performs sparse computation. The sparse computation can be run on the host by setting `target` to CPU. The sparse feature is under continuous development.

        When parameters are not grouped, the `weight_decay` configured in the optimizer is applied to the network parameters whose names do not contain "beta" or "gamma". Users can group parameters to change the weight decay strategy. When parameters are grouped, each group can configure its own `weight_decay`; if it is not configured, the group uses the `weight_decay` configured in the optimizer.

    **Parameters:**

    - **params** (Union[list[Parameter], list[dict]]) - Must be a list of `Parameter` or a list of dict. When the list elements are dicts, the keys of the dict can be "params", "lr", "weight_decay", "grad_centralization" and "order_params":

      - **params** - Required. The weights in the current group. The value must be a list of `Parameter`.
      - **lr** - Optional. If "lr" is in the keys, the corresponding value is used as the learning rate. If not, the `learning_rate` configured in the optimizer is used.
      - **weight_decay** - Optional. If "weight_decay" is in the keys, the corresponding value is used as the weight decay. If not, the `weight_decay` configured in the optimizer is used.
      - **grad_centralization** - Optional. If "grad_centralization" is in the keys, the corresponding value is used; it must be a bool. If not, `grad_centralization` is treated as False. This configuration only takes effect on convolution layers.
      - **order_params** - Optional. The corresponding value is the order in which the parameters are expected to be updated. When parameter grouping is used, this configuration is usually used to keep the order of `parameters` for better performance. If "order_params" is in the keys, the other keys in this group configuration are ignored. The parameters in "order_params" must also be present in the `params` of one of the groups.

    - **learning_rate** (Union[float, int, Tensor, Iterable, LearningRateSchedule]) - The accepted forms are listed below; see also the sketch after this parameter list. Default: 1e-3.

      - **float** - A fixed learning rate. Must be greater than or equal to zero.
      - **int** - A fixed learning rate. Must be greater than or equal to zero. It is converted to float.
      - **Tensor** - Either a scalar or a 1-D vector. A scalar is a fixed learning rate. A vector is a dynamic learning rate: at step i, the i-th value of the vector is taken as the learning rate.
      - **Iterable** - A dynamic learning rate: at step i, the i-th value of the iterable is taken as the learning rate.
      - **LearningRateSchedule** - A dynamic learning rate. During training, the optimizer calls the `LearningRateSchedule` instance with the current step as input to compute the current learning rate.

    - **beta1** (float) - The exponential decay rate for `moment1`. Must be in the range (0.0, 1.0). Default: 0.9.
    - **beta2** (float) - The exponential decay rate for `moment2`. Must be in the range (0.0, 1.0). Default: 0.999.
    - **eps** (float) - Added to the denominator to improve numerical stability. Must be greater than 0. Default: 1e-8.
    - **use_locking** (bool) - Whether to protect the parameter updates with a lock. If True, updates of the `w`, `m` and `v` tensors are protected by a lock. If False, the result is unpredictable. Default: False.
    - **use_nesterov** (bool) - Whether to update the gradients with the Nesterov Accelerated Gradient (NAG) algorithm. If True, the gradients are updated using NAG. If False, the gradients are updated without NAG. Default: False.
    - **weight_decay** (float) - The weight decay (L2 penalty). Must be greater than or equal to 0. Default: 0.0.
    - **loss_scale** (float) - The gradient scaling coefficient. Must be greater than 0. If `loss_scale` is an integer, it is converted to float. In general, use the default value. Only when `FixedLossScaleManager` is used for training and its `drop_overflow_update` attribute is False does this value need to be the same as the `loss_scale` of the `FixedLossScaleManager`; see the second sketch after this list. For more details, see :class:`mindspore.FixedLossScaleManager`. Default: 1.0.
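
    As an illustration of the accepted `learning_rate` forms, the following sketch passes a fixed float, a 1-D Tensor and a `LearningRateSchedule` (here :class:`mindspore.nn.ExponentialDecayLR`); `Net` is a placeholder network, as in the examples below.

    .. code-block:: python

        import numpy as np
        import mindspore
        from mindspore import Tensor, nn

        net = Net()  # placeholder network, as in the examples below
        # 1) Fixed learning rate: a non-negative float.
        optim = nn.Adam(net.trainable_params(), learning_rate=1e-3)
        # 2) Dynamic learning rate: a 1-D Tensor; step i uses the i-th value.
        lr = Tensor(np.array([0.1, 0.05, 0.01]), mindspore.float32)
        optim = nn.Adam(net.trainable_params(), learning_rate=lr)
        # 3) Dynamic learning rate: a LearningRateSchedule instance, called with
        #    the current step to compute the learning rate.
        schedule = nn.ExponentialDecayLR(learning_rate=0.1, decay_rate=0.9, decay_steps=4)
        optim = nn.Adam(net.trainable_params(), learning_rate=schedule)

    And, continuing from the previous sketch, a sketch of pairing `loss_scale` with a `FixedLossScaleManager` whose `drop_overflow_update` is False, under the constraint stated above that the two values must match:

    .. code-block:: python

        from mindspore import FixedLossScaleManager, Model

        loss_scale = 1024.0
        manager = FixedLossScaleManager(loss_scale, drop_overflow_update=False)
        optim = nn.Adam(net.trainable_params(), learning_rate=0.1, loss_scale=loss_scale)
        loss = nn.SoftmaxCrossEntropyWithLogits()
        model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=manager)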

    **Inputs:**

    - **gradients** (tuple[Tensor]) - The gradients of `params`, with the same shape as `params`.

    **Outputs:**

    Tensor[bool], the value is True.

    **Raises:**

    - **TypeError** - `learning_rate` is not an int, a float, a Tensor, an Iterable or a LearningRateSchedule.
    - **TypeError** - An element of `parameters` is neither a Parameter nor a dict.
    - **TypeError** - `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
    - **TypeError** - `weight_decay` is neither a float nor an int.
    - **TypeError** - `use_locking` or `use_nesterov` is not a bool.
    - **ValueError** - `loss_scale` or `eps` is less than or equal to 0.
    - **ValueError** - `beta1` or `beta2` is not in the range (0.0, 1.0).
    - **ValueError** - `weight_decay` is less than 0.

    **Supported Platforms:**

    ``Ascend`` ``GPU`` ``CPU``

    **Examples:**

    >>> net = Net()
    >>> #1) All parameters use the same learning rate and weight decay
    >>> optim = nn.Adam(params=net.trainable_params())
    >>>
    >>> #2) Use parameter groups and set different values
    >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
    >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
    >>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization':True},
    ...                 {'params': no_conv_params, 'lr': 0.01},
    ...                 {'order_params': net.trainable_params()}]
    >>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
    >>> # The conv_params group uses the optimizer's learning rate 0.1, this group's
    >>> # weight decay 0.01 and this group's gradient centralization True.
    >>> # The no_conv_params group uses this group's learning rate 0.01, the optimizer's
    >>> # weight decay 0.0 and the default gradient centralization False.
    >>> # The optimizer updates the parameters in the order configured by "order_params".
    >>>
    >>> loss = nn.SoftmaxCrossEntropyWithLogits()
    >>> model = Model(net, loss_fn=loss, optimizer=optim)

    .. py:method:: target
        :property:

        This property is used to specify whether the parameters are updated on the host or on the device. The input type is str and can only be 'CPU', 'Ascend' or 'GPU'.
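
        Following the note above on sparse computation, the property could be set as in this sketch (`net` is a placeholder network):

        .. code-block:: python

            optim = nn.Adam(net.trainable_params())
            # Run the sparse optimizer computation on the host when the forward
            # network uses sparse operators such as SparseGatherV2.
            optim.target = 'CPU'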