class mindspore.train.callback.LearningRateScheduler(learning_rate_function)

A callback that changes the learning rate during training.
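
Parameters:
    learning_rate_function (Function): The function that changes the learning rate during training. It is called as learning_rate_function(lr, cur_step_num) with the current learning rate and the current step number, and its return value becomes the new learning rate.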
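Because the function receives the current learning rate rather than the initial one (this sketch assumes so, following the (lr, cur_step_num) signature used in the example), a schedule defined against a fixed base value has to capture that base itself. The minimal pure-Python sketch below builds such a function with a closure. `make_piecewise_schedule` and its `boundaries` and `values` parameters are hypothetical helpers for illustration, not part of MindSpore.

```python
import bisect
from typing import Callable, Sequence


def make_piecewise_schedule(
    boundaries: Sequence[int], values: Sequence[float]
) -> Callable[[float, int], float]:
    """Build a hypothetical learning_rate_function for LearningRateScheduler.

    Steps before boundaries[0] use values[0]; steps in
    [boundaries[i], boundaries[i + 1]) use values[i + 1].
    boundaries must be sorted in ascending order.
    """
    if len(values) != len(boundaries) + 1:
        raise ValueError("values must have exactly one more entry than boundaries")

    def learning_rate_function(lr: float, cur_step_num: int) -> float:
        # The incoming lr is ignored: the result depends only on the step
        # number, so repeated calls do not compound.
        return values[bisect.bisect_right(boundaries, cur_step_num)]

    return learning_rate_function


schedule = make_piecewise_schedule(boundaries=[1000, 2000], values=[0.1, 0.01, 0.001])
print(schedule(0.1, 999))
print(schedule(0.1, 1000))
print(schedule(0.01, 2500))
```

The design choice here is to make the schedule absolute: whatever rate the callback passes in, the step number alone decides the result, which avoids surprises if the function is called more than once for the same step.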

Examples:
    >>> from mindspore import Model
    >>> from mindspore.train.callback import LearningRateScheduler
    >>> import mindspore.nn as nn
    ...
    >>> # Multiply the learning rate by 0.1 on every 1000th step.
    >>> def learning_rate_function(lr, cur_step_num):
    ...     if cur_step_num % 1000 == 0:
    ...         lr = lr * 0.1
    ...     return lr
    ...
    >>> lr = 0.1
    >>> momentum = 0.9
    >>> net = Net()  # Net: a user-defined network (placeholder)
    >>> loss = nn.SoftmaxCrossEntropyWithLogits()
    >>> optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
    >>> model = Model(net, loss_fn=loss, optimizer=optim)
    ...
    >>> dataset = create_custom_dataset("custom_dataset_path")  # user-defined loader (placeholder)
    >>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
    ...             dataset_sink_mode=False)
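Since the returned rate is fed back in as lr on the next call, the decay in the example compounds: the rate is scaled by 0.1 at step 1000, again at step 2000, and so on. The pure-Python trace below replays that feedback loop without MindSpore, assuming step numbers start at 1 and the function is called once at the end of every step. `trace_schedule` is a hypothetical helper, not a MindSpore API.

```python
def learning_rate_function(lr, cur_step_num):
    # Same rule as the example: scale by 0.1 on every 1000th step.
    if cur_step_num % 1000 == 0:
        lr = lr * 0.1
    return lr


def trace_schedule(fn, initial_lr, num_steps):
    """Call fn once per step, feeding each result back in as the next lr.

    Returns a dict mapping each step at which the rate changed to the new rate.
    """
    lr = initial_lr
    changes = {}
    for step in range(1, num_steps + 1):
        new_lr = fn(lr, step)
        if new_lr != lr:
            changes[step] = new_lr
        lr = new_lr
    return changes


for step, new_lr in trace_schedule(learning_rate_function, 0.1, 3000).items():
    print(f"after step {step}: lr = {new_lr:.0e}")
```

Over 3000 steps the rate changes three times, at steps 1000, 2000 and 3000, ending near 1e-4 rather than staying at 0.01.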

step_end(run_context)

Changes the learning rate at the end of each step.

Parameters:
    run_context (RunContext): Contains basic information about the model.
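The read, call, write cycle that step_end performs can be sketched in plain Python. This is a simplified stand-in, not MindSpore's implementation: `FakeOptimizer`, `FakeRunContext`, `SketchLearningRateScheduler` and `halve_every_two_steps` are hypothetical, and the sketch assumes that step_end obtains the optimizer and cur_step_num from run_context.original_args(), calls learning_rate_function(lr, cur_step_num), and writes the result back to the optimizer.

```python
from types import SimpleNamespace


class FakeOptimizer:
    """Hypothetical stand-in holding a scalar learning rate."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate


class FakeRunContext:
    """Hypothetical stand-in for RunContext; original_args() returns the callback params."""

    def __init__(self, optimizer):
        self._params = SimpleNamespace(optimizer=optimizer, cur_step_num=0)

    def original_args(self):
        return self._params


class SketchLearningRateScheduler:
    """Simplified sketch of the read-call-write cycle assumed for step_end."""

    def __init__(self, learning_rate_function):
        self.learning_rate_function = learning_rate_function

    def step_end(self, run_context):
        params = run_context.original_args()
        lr = params.optimizer.learning_rate  # read the current rate
        new_lr = self.learning_rate_function(lr, params.cur_step_num)
        if new_lr != lr:  # write back only when the rate actually changed
            params.optimizer.learning_rate = new_lr


def halve_every_two_steps(lr, cur_step_num):
    # Hypothetical schedule used only for this demonstration.
    return lr / 2 if cur_step_num % 2 == 0 else lr


optimizer = FakeOptimizer(learning_rate=1.0)
context = FakeRunContext(optimizer)
scheduler = SketchLearningRateScheduler(halve_every_two_steps)
for _ in range(6):
    context.original_args().cur_step_num += 1  # the training loop advances the step counter
    scheduler.step_end(context)
print(optimizer.learning_rate)
```

After six steps the rate has been halved at steps 2, 4 and 6, so the fake optimizer holds 0.125.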