mindspore.train.callback.LearningRateScheduler
===============================================
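
.. py:class:: mindspore.train.callback.LearningRateScheduler(learning_rate_function)

Change the learning rate during training.

**Parameters:**

**learning_rate_function** (Function) - The function that changes the learning rate during training. It is called as ``learning_rate_function(lr, cur_step_num)`` with the current learning rate and the current step number, and returns the new learning rate.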
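
As a sketch of the ``learning_rate_function`` contract described above, the plain-Python code below builds a function with the ``(lr, cur_step_num) -> new_lr`` signature. The cosine curve and the names ``make_cosine_schedule``, ``base_lr``, ``total_steps`` and ``min_lr`` are hypothetical and not part of MindSpore. The point it illustrates, under the assumption that the callback passes the optimizer's current rate back in on every call, is that a schedule defined in absolute terms should ignore the incoming ``lr`` and recompute from a captured base value, whereas a multiplicative rule (like the one in the example below) compounds from call to call.

```python
import math


def make_cosine_schedule(base_lr, total_steps, min_lr=0.0):
    """Return a hypothetical (lr, cur_step_num) -> new_lr function that follows
    a cosine curve from base_lr down to min_lr over total_steps steps."""

    def cosine_schedule(lr, cur_step_num):
        # The incoming lr is deliberately ignored: the rate is recomputed from
        # the captured base_lr on each call, so repeated calls never compound.
        progress = min(cur_step_num, total_steps) / total_steps
        return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

    return cosine_schedule


schedule = make_cosine_schedule(base_lr=0.1, total_steps=1000)
for step in (0, 500, 1000, 1500):
    # Steps past total_steps are clamped, so the rate stays at min_lr.
    print(step, schedule(0.1, step))
```

Such a closure could be passed wherever ``learning_rate_function`` is expected, for example ``LearningRateScheduler(schedule)``.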

**Examples:**

>>> from mindspore import Model
>>> from mindspore.train.callback import LearningRateScheduler
>>> import mindspore.nn as nn
...
>>> def learning_rate_function(lr, cur_step_num):
...     if cur_step_num % 1000 == 0:
...         lr = lr * 0.1
...     return lr
...
>>> lr = 0.1
>>> momentum = 0.9
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
...
>>> dataset = create_custom_dataset("custom_dataset_path")
>>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
...             dataset_sink_mode=False)
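
To see what the example's ``learning_rate_function`` does over a run, the plain-Python simulation below calls it once per step and feeds each result back in. It assumes that step numbers start at 1 and that the scheduler passes the optimizer's current learning rate into the function each time; ``simulate`` is a hypothetical helper, not a MindSpore API. Under those assumptions the decay compounds: 0.1 until step 999, about 0.01 from step 1000, and about 0.001 from step 2000.

```python
import math


def learning_rate_function(lr, cur_step_num):
    # Same rule as in the example: scale lr by 0.1 every 1000 steps.
    if cur_step_num % 1000 == 0:
        lr = lr * 0.1
    return lr


def simulate(lr_fn, initial_lr, num_steps):
    """Hypothetical helper: call lr_fn once per step, feeding each result back
    in, and record the rate in effect after each step (index i -> step i + 1).

    Assumes 1-based step numbers and that the current rate is passed back in."""
    lr = initial_lr
    history = []
    for step in range(1, num_steps + 1):
        lr = lr_fn(lr, step)
        history.append(lr)
    return history


history = simulate(learning_rate_function, initial_lr=0.1, num_steps=2500)
for step in (999, 1000, 2000, 2500):
    print(step, history[step - 1])
```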

.. py:method:: step_end(run_context)

Change the learning rate at the end of each step.

**Parameters:**

**run_context** (RunContext) - Holds basic information about the model during training.
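
The following is a simplified, plain-Python model of what a ``step_end`` hook like this one could do at the end of each step: read the current rate, call ``learning_rate_function`` with the current step number, and write the result back. It is not MindSpore's implementation. ``HypotheticalLRScheduler`` and ``FakeRunContext`` are hypothetical stand-ins, the assumption that ``run_context.original_args()`` exposes ``optimizer`` and ``cur_step_num`` should be checked against the RunContext documentation, and the write-back-only-on-change check is a design choice of this sketch.

```python
import math
from types import SimpleNamespace


class HypotheticalLRScheduler:
    """Plain-Python model of a learning-rate scheduler callback (not MindSpore code)."""

    def __init__(self, learning_rate_function):
        self.learning_rate_function = learning_rate_function

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        lr = cb_params.optimizer.learning_rate  # rate currently in the optimizer
        new_lr = self.learning_rate_function(lr, cb_params.cur_step_num)
        if not math.isclose(lr, new_lr, rel_tol=1e-10):
            # Only write back when the rate actually changes.
            cb_params.optimizer.learning_rate = new_lr


class FakeRunContext:
    """Hypothetical stand-in for RunContext that exposes original_args()."""

    def __init__(self, cb_params):
        self._cb_params = cb_params

    def original_args(self):
        return self._cb_params


def learning_rate_function(lr, cur_step_num):
    if cur_step_num % 1000 == 0:
        lr = lr * 0.1
    return lr


optimizer = SimpleNamespace(learning_rate=0.1)
cb_params = SimpleNamespace(optimizer=optimizer, cur_step_num=0)
run_context = FakeRunContext(cb_params)
scheduler = HypotheticalLRScheduler(learning_rate_function)

for step in range(1, 2001):
    cb_params.cur_step_num = step
    # ... one training step would run here ...
    scheduler.step_end(run_context)

print(optimizer.learning_rate)
```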