mindspore.nn.Optimizer
======================
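
.. py:class:: mindspore.nn.Optimizer(learning_rate, parameters, weight_decay=0.0, loss_scale=1.0)

    Base class for optimizers that update parameters. Do not use this class directly; instantiate one of its subclasses instead.

    Parameter grouping is supported. When parameters are grouped, each group can use its own learning rate (`lr`), weight decay (`weight_decay`) and gradient centralization (`grad_centralization`) strategy.

    .. note::
        When parameters are not grouped, the optimizer's `weight_decay` is applied to the network parameters whose names contain neither "beta" nor "gamma". Grouping the parameters lets you change this weight-decay strategy: each group can set its own `weight_decay`, and a group that does not set one uses the optimizer's `weight_decay`.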
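The weight-decay rule in the note can be illustrated with a small pure-Python sketch. It is not MindSpore code: parameters are represented by plain name strings, `resolve_weight_decay` is a hypothetical helper, and it assumes, as the note implies, that the "beta"/"gamma" name filter applies only to ungrouped parameters.

```python
def resolve_weight_decay(parameters, weight_decay=0.0):
    """Map each parameter name to the weight decay it receives (hypothetical helper).

    `parameters` is either a list of names (ungrouped) or a list of dicts with a
    "params" key (grouped), mirroring the two input forms of the optimizer.
    """
    result = {}
    if parameters and isinstance(parameters[0], dict):
        # Grouped: a group's own "weight_decay" overrides the optimizer-level value.
        for group in parameters:
            group_wd = group.get("weight_decay", weight_decay)
            for name in group["params"]:
                result[name] = group_wd
    else:
        # Ungrouped: names containing "beta" or "gamma" receive no weight decay.
        for name in parameters:
            skip = "beta" in name or "gamma" in name
            result[name] = 0.0 if skip else weight_decay
    return result


print(resolve_weight_decay(["conv.weight", "bn.gamma", "bn.beta"], 0.01))
# {'conv.weight': 0.01, 'bn.gamma': 0.0, 'bn.beta': 0.0}
```

Grouping is what lets a normalization parameter such as "bn.gamma" receive a nonzero decay if you explicitly want one.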
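
    **Parameters:**

    - **learning_rate** (Union[float, int, Tensor, Iterable, LearningRateSchedule]):

      - **float** - Fixed learning rate. Must be greater than or equal to 0.
      - **int** - Fixed learning rate. Must be greater than or equal to 0. An integer is converted to float.
      - **Tensor** - A scalar or a 1-D vector. A scalar gives a fixed learning rate; a 1-D vector gives a dynamic learning rate, where step i uses the i-th value of the vector.
      - **Iterable** - Dynamic learning rate. Step i uses the i-th value of the iterable.
      - **LearningRateSchedule** - Dynamic learning rate. During training, the optimizer calls the `LearningRateSchedule` instance with the current step as input to compute the learning rate for that step.

    - **parameters** (Union[list[Parameter], list[dict]]) - A list of `Parameter` objects or a list of dicts. When the elements are dicts, the recognized keys are "params", "lr", "weight_decay", "grad_centralization" and "order_params":

      - **params** - Required. The weights in the current group. The value must be a list of `Parameter`.
      - **lr** - Optional. If "lr" is present, its value is used as the learning rate of the group; otherwise the optimizer's `learning_rate` is used.
      - **weight_decay** - Optional. If "weight_decay" is present, its value is used as the weight decay of the group; otherwise the optimizer's `weight_decay` is used.
      - **grad_centralization** - Optional. If "grad_centralization" is present, its value is used and must be a bool; otherwise `grad_centralization` is treated as False. This setting applies only to convolution layers.
      - **order_params** - Optional. The expected update order of the parameters. When parameter grouping is used, this key is typically set to preserve the order of `parameters` for better performance. If "order_params" is present, the other keys in the same dict are ignored, and every parameter in "order_params" must belong to one of the `params` groups.

    - **weight_decay** (Union[float, int]) - Weight decay, an integer or floating-point value. Must be greater than or equal to 0. An integer `weight_decay` is converted to float. Default: 0.0.
    - **loss_scale** (float) - Gradient scaling factor. Must be greater than 0. An integer `loss_scale` is converted to float. The default value is usually appropriate. Only when training uses `FixedLossScaleManager` with its `drop_overflow_update` set to False must this value equal the `loss_scale` of the `FixedLossScaleManager`. See :class:`mindspore.FixedLossScaleManager` for details. Default: 1.0.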
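A minimal sketch of how a list-of-dict `parameters` argument resolves to per-parameter settings under the key rules listed above. `parse_groups` is hypothetical, parameters are plain name strings, and it assumes a dict holding "order_params" only sets the update order; it illustrates the rules and is not MindSpore's own parser.

```python
def parse_groups(groups, learning_rate, weight_decay=0.0):
    """Resolve grouped settings per parameter name (hypothetical helper).

    Returns (settings, order): settings maps each name to its lr, weight_decay
    and grad_centralization; order is the parameter update order.
    """
    settings, order = {}, None
    for group in groups:
        if "order_params" in group:
            # Other keys in this dict are ignored; only the order is recorded.
            order = list(group["order_params"])
            continue
        gc = group.get("grad_centralization", False)
        if not isinstance(gc, bool):
            raise TypeError("grad_centralization must be a bool")
        for name in group["params"]:
            settings[name] = {
                "lr": group.get("lr", learning_rate),
                "weight_decay": group.get("weight_decay", weight_decay),
                "grad_centralization": gc,
            }
    if order is None:
        order = list(settings)  # default: the order in which groups list parameters
    else:
        missing = [name for name in order if name not in settings]
        if missing:
            raise ValueError(f"order_params not found in any group: {missing}")
    return settings, order


settings, order = parse_groups(
    [{"params": ["conv.w"], "lr": 0.05, "grad_centralization": True},
     {"params": ["fc.w", "fc.b"]},
     {"order_params": ["fc.w", "conv.w", "fc.b"]}],
    learning_rate=0.1, weight_decay=0.01)
print(order)  # ['fc.w', 'conv.w', 'fc.b']
```

Keys a group omits fall back to the optimizer-level arguments, which is why "fc.w" and "fc.b" end up with `lr` 0.1.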

    **Raises:**

    - **TypeError** - `learning_rate` is not an int, float, Tensor, Iterable or LearningRateSchedule.
    - **TypeError** - An element of `parameters` is neither a `Parameter` nor a dict.
    - **TypeError** - `loss_scale` is not a float.
    - **TypeError** - `weight_decay` is neither a float nor an int.
    - **ValueError** - `loss_scale` is less than or equal to 0.
    - **ValueError** - `weight_decay` is less than 0.
    - **ValueError** - `learning_rate` is a Tensor with more than one dimension.

    **Supported Platforms:**

    ``Ascend`` ``GPU`` ``CPU``
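
    .. py:method:: broadcast_params(optim_result)

        Broadcast the parameters in the order of the parameter groups.

        **Parameters:**

        **optim_result** (bool) - The result of the parameter update. This input ensures that the parameters are broadcast only after the update has finished.

        **Returns:**

        bool, the status flag.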
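
    .. py:method:: decay_weight(gradients)

        Apply weight decay.

        Weight decay is a method for reducing overfitting in deep learning models. A custom optimizer that subclasses :class:`mindspore.nn.Optimizer` can also call this interface to apply weight decay.

        **Parameters:**

        **gradients** (tuple[Tensor]) - The gradients of the network parameters; each has the same shape as its parameter.

        **Returns:**

        tuple[Tensor], the gradients after weight decay.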
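Weight decay as a gradient adjustment means adding `weight_decay * w` to the gradient of each decayed parameter `w`, which is the gradient of an L2 penalty. A minimal sketch with flat Python lists standing in for Tensors; `decay_weight` and `decay_flags` here are hypothetical names, and MindSpore's actual kernel may differ in detail.

```python
def decay_weight(params, gradients, weight_decay, decay_flags):
    """Return gradients with weight_decay * w added where the flag is True.

    params and gradients are tuples of equally long flat lists standing in for
    Tensors; decay_flags holds one bool per parameter (hypothetical names).
    """
    decayed = []
    for w, g, flag in zip(params, gradients, decay_flags):
        if flag:
            # Gradient of the L2 penalty (weight_decay / 2) * ||w||^2 is weight_decay * w.
            decayed.append([gi + weight_decay * wi for gi, wi in zip(g, w)])
        else:
            decayed.append(list(g))
    return tuple(decayed)


print(decay_weight(([1.0, 2.0], [4.0]), ([0.5, 0.5], [1.0]), 0.5, (True, False)))
# ([1.0, 1.5], [1.0])
```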
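
    .. py:method:: get_lr()

        The optimizer calls this interface to get the learning rate of the current step. A custom optimizer that subclasses :class:`mindspore.nn.Optimizer` can also call it before updating the parameters to obtain the learning rate.

        **Returns:**

        float, the learning rate of the current step.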
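A sketch of how each `learning_rate` form maps to the rate used at a given step. Python sequences stand in for a 1-D Tensor or an Iterable, and any callable stands in for a `LearningRateSchedule`; `lr_at_step` is hypothetical and does not model what MindSpore does once the step exceeds the length of a dynamic learning-rate sequence.

```python
from collections.abc import Sequence


def lr_at_step(learning_rate, step):
    """Return the learning rate used at `step` (hypothetical helper).

    int/float -> fixed rate; Sequence (stand-in for a 1-D Tensor or an Iterable)
    -> its step-th value; callable (stand-in for a LearningRateSchedule) -> its
    value when called with the step.
    """
    if isinstance(learning_rate, bool):
        raise TypeError("a bool is not a valid learning rate")
    if isinstance(learning_rate, (int, float)):
        if learning_rate < 0:
            raise ValueError("learning_rate must be >= 0")
        return float(learning_rate)  # ints are converted to float
    if callable(learning_rate):
        return float(learning_rate(step))
    if isinstance(learning_rate, Sequence) and not isinstance(learning_rate, str):
        # Raises IndexError past the end; the real behaviour there is not modelled.
        return float(learning_rate[step])
    raise TypeError("unsupported learning_rate type")


print(lr_at_step(1, 5), lr_at_step([0.1, 0.05, 0.01], 1))  # 1.0 0.05
```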
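
    .. py:method:: get_lr_parameter(param)

        When parameter grouping is used and the groups have different learning rates, get the learning rate of the specified parameters.

        **Parameters:**

        **param** (Union[Parameter, list[Parameter]]) - A `Parameter` or a list of `Parameter`.

        **Returns:**

        Parameter, a single `Parameter` or a list of `Parameter`, depending on the input type. If a dynamic learning rate is used, the `LearningRateSchedule` (or list of `LearningRateSchedule`) that computes the learning rate is returned instead.

        **Examples:**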

        >>> from mindspore import nn, ops
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.conv = nn.Conv2d(3, 8, 3)
        ...         self.fc = nn.Dense(8, 2)
        ...         self.mean = ops.ReduceMean()
        ...     def construct(self, x):
        ...         return self.fc(self.mean(self.conv(x), (2, 3)))
        ...
        >>> net = Net()
        >>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
        >>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
        >>> group_params = [{'params': conv_params, 'lr': 0.05},
        ...                 {'params': no_conv_params, 'lr': 0.01}]
        >>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
        >>> conv_lr = optim.get_lr_parameter(conv_params)
        >>> print(conv_lr[0].asnumpy())
        0.05
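
    .. py:method:: gradients_centralization(gradients)

        Apply gradient centralization.

        Gradient centralization is a method that optimizes the parameters of convolution layers to speed up the training of deep learning models. A custom optimizer that subclasses :class:`mindspore.nn.Optimizer` can also call this interface to centralize gradients.

        **Parameters:**

        **gradients** (tuple[Tensor]) - The gradients of the network parameters; each has the same shape as its parameter.

        **Returns:**

        tuple[Tensor], the gradients after gradient centralization.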
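Gradient centralization, as usually defined, subtracts from a convolution weight's gradient its mean over every axis except the output-channel axis. In this pure-Python sketch each gradient is a list of rows, one row per output channel with the remaining axes flattened; `centralize`, `gradients_centralization` and the `gc_flags` argument are hypothetical stand-ins, not the MindSpore implementation.

```python
def centralize(rows):
    """Subtract from each output-channel row its own mean."""
    out = []
    for row in rows:
        mean = sum(row) / len(row)
        out.append([x - mean for x in row])
    return out


def gradients_centralization(gradients, gc_flags):
    # Only gradients flagged for centralization (convolution weights) are changed.
    return tuple(centralize(g) if flag else g for g, flag in zip(gradients, gc_flags))


grads = ([[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]], [[5.0]])
print(gradients_centralization(grads, (True, False)))
# ([[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]], [[5.0]])
```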
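
    .. py:method:: scale_grad(gradients)

        Restore the gradients in mixed-precision scenarios.

        A custom optimizer that subclasses :class:`mindspore.nn.Optimizer` can also call this interface to restore the gradients.

        **Parameters:**

        **gradients** (tuple[Tensor]) - The gradients of the network parameters; each has the same shape as its parameter.

        **Returns:**

        tuple[Tensor], the restored gradients.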
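With loss scaling, the loss is multiplied by `loss_scale` before backpropagation, so every gradient carries the same factor; restoring them means multiplying by `1 / loss_scale`. A minimal sketch on Python lists, using a hypothetical `scale_grad` helper rather than the MindSpore method:

```python
def scale_grad(gradients, loss_scale=1.0):
    """Undo a fixed loss scale by multiplying every element by 1 / loss_scale."""
    if loss_scale == 1.0:
        return gradients  # nothing to undo
    reciprocal = 1.0 / loss_scale
    return tuple([x * reciprocal for x in grad] for grad in gradients)


print(scale_grad(([1024.0, 2048.0], [512.0]), loss_scale=1024.0))
# ([1.0, 2.0], [0.5])
```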

    .. py:method:: target
        :property:

        Specifies whether the parameters are updated on the host or on the device. The value is a str and can only be 'CPU', 'Ascend' or 'GPU'.
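
    .. py:method:: unique
        :property:

        Whether the optimizer deduplicates gradients; this is typically used in sparse networks. Set it to True if the gradients are sparse. Set it to False if the forward sparse network has already deduplicated the weights, so that the gradients are dense. Defaults to True when not set.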
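A sparse gradient, for example from an embedding lookup, can list the same row index several times; deduplication merges those entries by summing their values. A pure-Python sketch with a hypothetical `unique_sparse_grad` helper; returning indices in sorted order is a choice of this sketch, not documented MindSpore behaviour.

```python
def unique_sparse_grad(indices, values):
    """Merge sparse-gradient rows that share an index by summing them."""
    merged = {}
    for idx, row in zip(indices, values):
        if idx in merged:
            merged[idx] = [a + b for a, b in zip(merged[idx], row)]
        else:
            merged[idx] = list(row)
    unique_indices = sorted(merged)  # sorted output is a choice of this sketch
    return unique_indices, [merged[i] for i in unique_indices]


print(unique_sparse_grad([3, 1, 3], [[1.0, 1.0], [2.0, 0.0], [0.5, 0.5]]))
# ([1, 3], [[2.0, 0.0], [1.5, 1.5]])
```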