
mindspore.nn.Optimizer

class mindspore.nn.Optimizer(learning_rate, parameters, weight_decay=0.0, loss_scale=1.0)

Base class for all optimizers. This class defines the API for updating parameters. Never use this class directly; instantiate one of its subclasses instead.

The optimizer supports parameter grouping. When parameters are grouped, each group can be configured with its own learning rate (`lr`), weight decay (`weight_decay`) and gradient centralization (`grad_centralization`) policy.

Note:
When parameters are not grouped, the `weight_decay` configured in the optimizer is applied to network parameters whose names do not contain "beta" or "gamma"; with grouping, the weight-decay policy can be adjusted per group. When grouped, each group of network parameters can configure its own `weight_decay`; if it is not set, the `weight_decay` configured in the optimizer is used.
Parameters:

learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]):

- float: A fixed learning rate. Must be greater than or equal to zero.
- int: A fixed learning rate. Must be greater than or equal to zero. The integer type is converted to float.
- Tensor: Either a scalar or a one-dimensional vector. A scalar is a fixed learning rate; a one-dimensional vector is a dynamic learning rate, where the i-th value in the vector is taken as the learning rate at step i.
- Iterable: A dynamic learning rate. The i-th value of the iterable is taken as the learning rate at step i.
- LearningRateSchedule: A dynamic learning rate. During training, the optimizer calls the `LearningRateSchedule` instance with the current step as input to compute the current learning rate.
parameters (Union[list[Parameter], list[dict]]): Must be a list of `Parameter` or a list of dict. When the list elements are dicts, the keys of the dict can be "params", "lr", "weight_decay", "grad_centralization" and "order_params":

- params: Required. The weights of the current group; the value must be a list of `Parameter`.
- lr: Optional. If "lr" is in the keys, the corresponding value is used as the learning rate. If not, the `learning_rate` configured in the optimizer is used.
- weight_decay: Optional. If "weight_decay" is in the keys, the corresponding value is used as the weight decay. If not, the `weight_decay` configured in the optimizer is used.
- grad_centralization: Optional. If "grad_centralization" is in the keys, the corresponding value is used; it must be of type bool. If not, `grad_centralization` is treated as False. This configuration only applies to convolution layers.
- order_params: Optional. The corresponding value is the expected update order of the parameters. When the parameter-grouping function is used, this configuration is usually used to preserve the order of `parameters` for better performance. If "order_params" is in the keys, the other keys in that group configuration are ignored. The parameters listed in "order_params" must be included in the `params` of one of the groups.
weight_decay (Union[float, int]): The int or float weight decay value. Must be greater than or equal to 0. If `weight_decay` is an int, it is converted to float. Default: 0.0.
loss_scale (float): The gradient scaling coefficient. Must be greater than 0. If `loss_scale` is an int, it is converted to float. The default value is usually kept; only when `FixedLossScaleManager` is used for training and its `drop_overflow_update` attribute is set to False does this value need to be the same as the `loss_scale` of the `FixedLossScaleManager`. For more details, see :class:`mindspore.FixedLossScaleManager`. Default: 1.0.
Raises:

- TypeError: `learning_rate` is not an int, float, Tensor, Iterable or LearningRateSchedule.
- TypeError: An element of `parameters` is neither a Parameter nor a dict.
- TypeError: `loss_scale` is not a float.
- TypeError: `weight_decay` is neither a float nor an int.
- ValueError: `loss_scale` is less than or equal to 0.
- ValueError: `weight_decay` is less than 0.
- ValueError: `learning_rate` is a Tensor, but the dimension of the Tensor is greater than 1.
Supported Platforms:

``Ascend`` ``GPU`` ``CPU``
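The non-schedule forms of `learning_rate` all reduce to "look up the learning rate for step i". A minimal pure-Python sketch of that dispatch (a hypothetical `resolve_lr` helper for illustration, not MindSpore source code; plain Python types stand in for Tensor):

```python
# Hypothetical sketch: how the learning_rate argument forms map to a
# per-step learning rate. Not the MindSpore implementation.
def resolve_lr(learning_rate, step):
    if isinstance(learning_rate, (int, float)):
        return float(learning_rate)          # fixed learning rate
    if callable(learning_rate):              # LearningRateSchedule-like object
        return learning_rate(step)
    return list(learning_rate)[step]         # Iterable / 1-D vector: i-th value

print(resolve_lr(0.1, 5))                    # fixed -> 0.1
print(resolve_lr([0.1, 0.05, 0.01], 2))      # dynamic -> 0.01
```

A `LearningRateSchedule` corresponds to the callable branch: it is invoked with the current step each time the optimizer runs.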
broadcast_params(optim_result)

Broadcast parameters in the order of the parameter groups.

Parameters:

- optim_result (bool): The result of updating the parameters. This input is used to ensure that the parameter broadcast is performed only after the parameter update has completed.

Returns:

bool, the status flag.
decay_weight(gradients)

Apply weight decay.

A method to reduce overfitting in deep learning neural network models. When customizing an optimizer by inheriting :class:`mindspore.nn.Optimizer`, this interface can be called to apply weight decay.

Parameters:

- gradients (tuple[Tensor]): The gradients of the network parameters, with the same shape as the parameters.

Returns:

tuple[Tensor], the gradients after weight decay.
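The weight-decay step can be pictured as adding `weight_decay * parameter` to each gradient (standard L2-style decay), which pulls the weights toward zero at the next update. A hedged pure-Python sketch, with plain floats standing in for Tensors (not the actual MindSpore kernel):

```python
# Sketch of L2-style weight decay: each gradient gets weight_decay * parameter
# added to it before the optimizer applies the update.
def decay_weight_sketch(gradients, parameters, weight_decay):
    return tuple(g + weight_decay * p for g, p in zip(gradients, parameters))

grads = (0.5, -0.2)
params = (1.0, 2.0)
decayed = decay_weight_sketch(grads, params, 0.01)  # ~ (0.51, -0.18)
```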
get_lr()

The optimizer calls this interface to get the learning rate of the current step. When customizing an optimizer by inheriting :class:`mindspore.nn.Optimizer`, this interface can be called before the parameter update to get the learning rate.

Returns:

float, the learning rate of the current step.
get_lr_parameter(param)

Get the learning rate of the specified parameter when the parameter-grouping function is used and different groups are configured with different learning rates.

Parameters:

- param (Union[Parameter, list[Parameter]]): A `Parameter` or a list of `Parameter`.

Returns:

Parameter, a single `Parameter` or a list of `Parameter`. If a dynamic learning rate is used, the `LearningRateSchedule` or list of `LearningRateSchedule` used to compute the learning rate is returned.
Examples:

>>> from mindspore import nn
>>> net = Net()
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'lr': 0.05},
...                 {'params': no_conv_params, 'lr': 0.01}]
>>> optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9, weight_decay=0.0)
>>> conv_lr = optim.get_lr_parameter(conv_params)
>>> print(conv_lr[0].asnumpy())
0.05
gradients_centralization(gradients)

Apply gradient centralization.

A method to optimize the training speed of deep learning neural network models. When customizing an optimizer by inheriting :class:`mindspore.nn.Optimizer`, this interface can be called to apply gradient centralization.

Parameters:

- gradients (tuple[Tensor]): The gradients of the network parameters, with the same shape as the parameters.

Returns:

tuple[Tensor], the gradients after gradient centralization.
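Gradient centralization subtracts, for each output channel, the mean of that channel's gradient values, so every channel's gradient has zero mean. A hedged sketch on a 2-D gradient, with nested Python lists standing in for a Tensor (rows as output channels is an assumption for illustration, not the MindSpore kernel):

```python
# Sketch of gradient centralization: subtract each row's mean so every
# row (output channel) of the gradient has zero mean.
def centralize(grad_rows):
    out = []
    for row in grad_rows:
        m = sum(row) / len(row)
        out.append([g - m for g in row])
    return out

g = [[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]]
print(centralize(g))  # -> [[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]]
```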
scale_grad(gradients)

Restore the gradients in mixed-precision scenarios.

When customizing an optimizer by inheriting :class:`mindspore.nn.Optimizer`, this interface can be called to restore the gradients.

Parameters:

- gradients (tuple[Tensor]): The gradients of the network parameters, with the same shape as the parameters.

Returns:

tuple[Tensor], the restored gradients.
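In mixed-precision training the loss is multiplied by `loss_scale` before backpropagation, so the resulting gradients are `loss_scale` times too large; restoring them is a division. A minimal sketch of that round trip (plain floats, and a power-of-two scale so the round trip is exact in binary floating point):

```python
# Sketch of the loss-scale round trip in mixed precision.
loss_scale = 1024.0                      # power of two: exact to scale/unscale
true_grads = (0.001, -0.0005)

scaled = tuple(g * loss_scale for g in true_grads)   # what backprop produces
restored = tuple(g / loss_scale for g in scaled)     # what restoring yields
print(restored == true_grads)  # -> True
```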
target

This attribute is used to specify whether the parameters are updated on the host or on the device. The input type is str and can only be 'CPU', 'Ascend' or 'GPU'.
unique

This attribute indicates whether gradient deduplication is performed in the optimizer, which is usually used for sparse networks. Set it to True if the gradients are sparse; set it to False if the forward sparse network has already deduplicated the weights, i.e. the gradients are dense. The default value is True when it is not set.
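For a sparse gradient, deduplication means merging the values that target the same row index, so each parameter row receives a single combined update. A hedged pure-Python sketch of the idea (a hypothetical dict-based `dedup` helper, not the device kernel):

```python
# Sketch of sparse-gradient deduplication: values at repeated indices are
# summed, yielding unique indices and one merged value per index.
def dedup(indices, values):
    merged = {}
    for i, v in zip(indices, values):
        merged[i] = merged.get(i, 0.0) + v
    idx = sorted(merged)
    return idx, [merged[i] for i in idx]

print(dedup([3, 1, 3], [0.5, 0.2, 0.25]))  # -> ([1, 3], [0.2, 0.75])
```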