class mindspore.nn.Adam(*args, **kwargs)
Updates gradients using the Adaptive Moment Estimation (Adam) algorithm.
See the paper `Adam: A Method for Stochastic Optimization <https://arxiv.org/abs/1412.6980>`_.
The update rule is as follows:
.. math::
    \begin{array}{ll} \\
        m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\
        v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\
        l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\
        w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
    \end{array}
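:math:`m` denotes the first moment vector `moment1`, :math:`v` the second moment vector `moment2`, :math:`g` the `gradients`, :math:`l` the scaling factor, :math:`\beta_1, \beta_2` the parameters `beta1` and `beta2`, :math:`t` the current step, :math:`\beta_1^t` and :math:`\beta_2^t` the values `beta1_power` and `beta2_power`, :math:`\alpha` the `learning_rate`, :math:`w` the `params`, and :math:`\epsilon` the `eps`.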
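The update rule can be traced by hand with a scalar sketch in plain Python. The sketch transcribes the four equations above for a single weight and is not MindSpore's kernel. The helper name `adam_step` is an assumption made for illustration. So is the convention that `t` equals 1 on the first update, which lets `beta1 ** t` play the role of `beta1_power`.

```python
import math


def adam_step(w, m, v, g, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update of a single scalar weight, following the formulas above.

    Hypothetical helper: t is the 1-based index of the current step, so
    beta1 ** t and beta2 ** t stand in for beta1_power and beta2_power.
    """
    m_new = beta1 * m + (1 - beta1) * g          # first moment (moment1)
    v_new = beta2 * v + (1 - beta2) * g * g      # second moment (moment2)
    # Bias correction is folded into the scaling factor l.
    l = alpha * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    w_new = w - l * m_new / (math.sqrt(v_new) + eps)
    return w_new, m_new, v_new


# First step from zero-initialized moments with gradient 0.5.
w, m, v = adam_step(w=1.0, m=0.0, v=0.0, g=0.5, t=1)
print(round(w, 6))  # 0.999
```

With zero-initialized moments, the first step moves `w` by almost exactly `alpha` against the sign of `g`. The bias correction in l cancels the `(1 - beta1)` and `sqrt(1 - beta2)` factors that the zero initialization introduces into `m` and `v`.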
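Note:
If the forward network uses operators such as SparseGatherV2, the optimizer performs sparse computation. Setting `target` to CPU runs the sparse computation on the host.
The sparse feature is still under development.
When parameters are not grouped, the `weight_decay` configured in the optimizer is applied to the network parameters whose names do not contain "beta" or "gamma". Grouping the network parameters lets you adjust the weight decay strategy: each group can configure its own `weight_decay`, and a group that does not configure it uses the `weight_decay` configured in the optimizer.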
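The ungrouped weight-decay rule in the note can be sketched as a name filter. This plain-Python illustration is not MindSpore code. The helper `decay_filter` and the parameter names are hypothetical and only mimic the "beta"/"gamma" rule stated above.

```python
def decay_filter(param_name):
    """Return True if the optimizer-level weight_decay would apply to this parameter.

    Mirrors the note: when parameters are not grouped, names containing
    "beta" or "gamma" (typically normalization offsets and scales) are skipped.
    """
    return "beta" not in param_name and "gamma" not in param_name


# Hypothetical parameter names, for illustration only.
names = ["conv1.weight", "bn1.gamma", "bn1.beta", "fc.weight", "fc.bias"]
decayed = [n for n in names if decay_filter(n)]
print(decayed)  # ['conv1.weight', 'fc.weight', 'fc.bias']
```

Under this rule a bias such as `fc.bias` is still decayed. Only the names matching "beta" or "gamma" are excluded.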
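Parameters:
params (Union[list[Parameter], list[dict]]): Must be a list of `Parameter` or a list of `dict`. When the list elements are dictionaries, the allowed keys are "params", "lr", "weight_decay", "grad_centralization" and "order_params":
- params: Required. The weights of the current group. The value must be a list of `Parameter`.
- lr: Optional. If "lr" is present, its value is used as the learning rate. Otherwise, the `learning_rate` configured in the optimizer is used.
- weight_decay: Optional. If "weight_decay" is present, its value is used as the weight decay. Otherwise, the `weight_decay` configured in the optimizer is used.
- grad_centralization: Optional. If "grad_centralization" is present, its value is used and must be a bool. Otherwise, `grad_centralization` is treated as False. This setting applies only to convolution layers.
- order_params: Optional. The value is the expected parameter update order. When parameter grouping is used, this key is typically set to preserve the order of `parameters` for better performance. If "order_params" is present, the other keys in that dictionary are ignored. Every parameter in "order_params" must belong to one of the groups in `params`.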
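learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]): Default: 1e-3.
- float: A fixed learning rate. Must be greater than or equal to zero.
- int: A fixed learning rate. Must be greater than or equal to zero. Integers are converted to floats.
- Tensor: Either a scalar or a 1-D vector. A scalar is a fixed learning rate. A 1-D vector is a dynamic learning rate: step i uses the i-th element as the learning rate.
- Iterable: A dynamic learning rate. Step i uses the i-th value of the iterable as the learning rate.
- LearningRateSchedule: A dynamic learning rate. During training, the optimizer calls the `LearningRateSchedule` instance with the step as input to compute the current learning rate.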
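beta1 (float): The exponential decay rate for `moment1`. Must be in the range (0.0, 1.0). Default: 0.9.
beta2 (float): The exponential decay rate for `moment2`. Must be in the range (0.0, 1.0). Default: 0.999.
eps (float): A term added to the denominator to improve numerical stability. Must be greater than 0. Default: 1e-8.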
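use_locking (bool): Whether to protect parameter updates with a lock. If True, updates to the `w`, `m` and `v` tensors are lock-protected. If False, the result is unpredictable. Default: False.
use_nesterov (bool): Whether to update the gradients with the Nesterov Accelerated Gradient (NAG) algorithm. If True, gradients are updated using NAG. If False, gradients are updated without NAG. Default: False.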
weight_decay (float): Weight decay (L2 penalty). Must be greater than or equal to 0. Default: 0.0.
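loss_scale (float): The gradient scaling factor. Must be greater than 0. If `loss_scale` is an integer, it is converted to a float. The default value is usually appropriate. The only exception is training with `FixedLossScaleManager` whose `drop_overflow_update` attribute is set to False: in that case this value must equal the `loss_scale` of the `FixedLossScaleManager`. For more details, see :class:`mindspore.FixedLossScaleManager`. Default: 1.0.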
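The fallback rules for group dictionaries in `params` can be summarized with a small resolver. This hypothetical pure-Python sketch restates the documented behavior and is not MindSpore code. `resolve_groups` is an invented name, parameters are plain strings, and the optimizer-level defaults are passed in explicitly. The `ValueError` raised for a stray entry in "order_params" is the sketch's own choice; this page does not say which exception MindSpore uses.

```python
def resolve_groups(groups, learning_rate, weight_decay):
    """Apply the documented per-group defaults to a list of group dicts.

    Returns (settings, order): settings maps each parameter to its effective
    (lr, weight_decay, grad_centralization); order is the update order.
    """
    settings = {}
    order = None
    for group in groups:
        if "order_params" in group:
            # Any other keys in an "order_params" dict are ignored.
            order = list(group["order_params"])
            continue
        lr = group.get("lr", learning_rate)              # fall back to optimizer lr
        wd = group.get("weight_decay", weight_decay)     # fall back to optimizer decay
        gc = group.get("grad_centralization", False)     # default is False
        for p in group["params"]:
            settings[p] = (lr, wd, gc)
    if order is None:
        order = list(settings)  # no explicit order: keep grouping order
    # Every parameter named in "order_params" must belong to some group.
    missing = [p for p in order if p not in settings]
    if missing:
        raise ValueError(f"order_params contains ungrouped parameters: {missing}")
    return settings, order


# Mirrors the grouped example in this page, with hypothetical parameter names.
groups = [
    {"params": ["conv1.weight"], "weight_decay": 0.01, "grad_centralization": True},
    {"params": ["fc.weight"], "lr": 0.01},
    {"order_params": ["fc.weight", "conv1.weight"]},
]
settings, order = resolve_groups(groups, learning_rate=0.1, weight_decay=0.0)
print(settings["conv1.weight"])  # (0.1, 0.01, True)
print(settings["fc.weight"])     # (0.01, 0.0, False)
print(order)                     # ['fc.weight', 'conv1.weight']
```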
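The `learning_rate` entry maps each accepted form to the rate used at step i. The following hypothetical pure-Python model sketches that mapping. A Python list stands in for a 1-D Tensor or an Iterable, and any plain callable stands in for a `LearningRateSchedule` instance. Steps are counted from 0, and the sketch does not define behavior past the end of a sequence.

```python
def lr_at_step(learning_rate, step):
    """Return the learning rate used at a 0-based step under the documented rules.

    Hypothetical model: a list/tuple stands in for a 1-D Tensor or Iterable,
    and any callable stands in for a LearningRateSchedule instance.
    """
    if isinstance(learning_rate, (int, float)):
        if learning_rate < 0:
            raise ValueError("a fixed learning rate must be >= 0")
        return float(learning_rate)          # fixed rate; ints become floats
    if callable(learning_rate):
        return float(learning_rate(step))    # schedule evaluated on the step
    return float(list(learning_rate)[step])  # dynamic: i-th value at step i


print(lr_at_step(1e-3, step=5))                      # 0.001
print(lr_at_step(2, step=0))                         # 2.0
print(lr_at_step([0.1, 0.05, 0.01], step=1))         # 0.05
print(lr_at_step(lambda s: 0.1 * 0.5 ** s, step=2))  # 0.025
```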
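The `weight_decay` entry describes an L2 penalty. For a penalty of `0.5 * weight_decay * w**2`, the gradient gains an extra `weight_decay * w` term. The sketch below adds that term in plain Python using a hypothetical helper, `decayed_gradient`. This page does not specify where `nn.Adam` applies the term internally or how it interacts with `loss_scale`, so treat the sketch as an illustration of the penalty only.

```python
def decayed_gradient(g, w, weight_decay):
    """Fold an L2 penalty of 0.5 * weight_decay * w**2 into the gradient.

    d/dw [0.5 * weight_decay * w**2] = weight_decay * w, so the penalized
    gradient is g + weight_decay * w.
    """
    if weight_decay < 0:
        raise ValueError("weight_decay must be >= 0")
    return g + weight_decay * w


print(decayed_gradient(g=0.5, w=2.0, weight_decay=0.01))  # 0.52
```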
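The `loss_scale` requirement follows from how fixed loss scaling works. During backpropagation the loss, and therefore every gradient, is multiplied by the scale, and the optimizer divides the gradients by its own `loss_scale` before updating. The hypothetical helper `unscale_gradients` sketches that division in plain Python. It shows why the two values must match when `drop_overflow_update` is False.

```python
def unscale_gradients(grads, loss_scale):
    """Divide scaled gradients by loss_scale (multiply by its reciprocal)."""
    if loss_scale <= 0:
        raise ValueError("loss_scale must be > 0")
    reciprocal = 1.0 / float(loss_scale)  # an int loss_scale is treated as a float
    return tuple(g * reciprocal for g in grads)


true_grads = (0.25, -0.5)
scale = 1024
scaled = tuple(g * scale for g in true_grads)       # what backprop produces
print(unscale_gradients(scaled, loss_scale=scale))  # (0.25, -0.5)
```

If the optimizer's `loss_scale` were left at 1.0 in this scenario, the update would use gradients 1024 times too large.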
Inputs:
- **gradients** (tuple[Tensor]) - The gradients of `params`, with the same shape as `params`.
Outputs:
Tensor[bool], with value True.
Raises:
TypeError: `learning_rate` is not an int, float, Tensor, Iterable, or LearningRateSchedule.
TypeError: An element of `parameters` is neither a Parameter nor a dict.
TypeError: `beta1`, `beta2`, `eps` or `loss_scale` is not a float.
TypeError: `weight_decay` is neither a float nor an int.
TypeError: `use_locking` or `use_nesterov` is not a bool.
ValueError: `loss_scale` or `eps` is less than or equal to 0.
ValueError: `beta1` or `beta2` is not in the range (0.0, 1.0).
ValueError: `weight_decay` is less than 0.
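The scalar checks listed above can be mirrored by a small validator. This hypothetical plain-Python function restates the documented conditions for `beta1`, `beta2`, `eps`, `weight_decay`, `loss_scale`, `use_locking` and `use_nesterov`. It is not MindSpore's validator, and its messages are invented. An int `loss_scale` is converted before the float check, following the `loss_scale` parameter entry.

```python
def check_adam_args(beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                    use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
    """Raise TypeError/ValueError under the conditions listed in Raises (sketch)."""
    if isinstance(loss_scale, int) and not isinstance(loss_scale, bool):
        loss_scale = float(loss_scale)  # ints are converted, per the loss_scale entry
    for name, value in (("beta1", beta1), ("beta2", beta2),
                        ("eps", eps), ("loss_scale", loss_scale)):
        if not isinstance(value, float):
            raise TypeError(f"{name} must be a float, got {type(value).__name__}")
    # bool is a subclass of int in Python, so exclude it explicitly.
    if isinstance(weight_decay, bool) or not isinstance(weight_decay, (float, int)):
        raise TypeError("weight_decay must be a float or an int")
    for name, value in (("use_locking", use_locking), ("use_nesterov", use_nesterov)):
        if not isinstance(value, bool):
            raise TypeError(f"{name} must be a bool")
    if loss_scale <= 0 or eps <= 0:
        raise ValueError("loss_scale and eps must be > 0")
    if not (0.0 < beta1 < 1.0 and 0.0 < beta2 < 1.0):
        raise ValueError("beta1 and beta2 must lie in (0.0, 1.0)")
    if weight_decay < 0:
        raise ValueError("weight_decay must be >= 0")


check_adam_args()  # the defaults pass silently
for bad in ({"beta1": 1.0}, {"eps": 0.0}, {"weight_decay": -0.1}):
    try:
        check_adam_args(**bad)
    except ValueError as err:
        print(type(err).__name__, bad)
```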
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> from mindspore import nn, Model
>>> # Net is a user-defined network (an nn.Cell subclass).
>>> net = Net()
>>> #1) All parameters use the same learning rate and weight decay
>>> optim = nn.Adam(params=net.trainable_params())
>>>
>>> #2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
...                 {'params': no_conv_params, 'lr': 0.01},
...                 {'order_params': net.trainable_params()}]
>>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The conv_params group uses the optimizer's learning rate 0.1, the group's weight decay 0.01, and the group's grad centralization setting True.
>>> # The no_conv_params group uses the group's learning rate 0.01, the optimizer's weight decay 0.0, and the default grad centralization setting False.
>>> # The optimizer updates the parameters in the order given by "order_params".
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)
target (property)
Specifies whether parameters are updated on the host or on the device. The value is a str and must be 'CPU', 'Ascend' or 'GPU'.