You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.train.callback.Callback.txt 2.4 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. Class mindspore.train.callback.Callback
  2. 用于构建回调函数的基类。回调函数是一个上下文管理器,在运行模型时被调用。
  3. 可以使用此机制进行初始化和释放资源等操作。
  4. 回调函数可以在step或epoch中的执行一些操作。
  5. 它保存模型相关信息。例如`network`、`train_network`、`epoch_num`、`batch_num`、`loss_fn`、`optimizer`、`parallel_mode`、`device_number`、`list_callback`、`cur_epoch_num`、`cur_step_num`、`dataset_sink_mode`、`net_outputs`等。
  6. 示例:
  7. >>> from mindspore import Model, nn
  8. >>> from mindspore.train.callback import Callback
  9. >>> class Print_info(Callback):
  10. ... def step_end(self, run_context):
  11. ... cb_params = run_context.original_args()
  12. ... print("step_num: ", cb_params.cur_step_num)
  13. >>>
  14. >>> print_cb = Print_info()
  15. >>> dataset = create_custom_dataset()
  16. >>> net = Net()
  17. >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  18. >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
  19. >>> model = Model(net, loss_fn=loss, optimizer=optim)
  20. >>> model.train(1, dataset, callbacks=print_cb)
  21. step_num:1
  22. begin(run_context)
  23. 在网络执行之前被调用一次。
  24. 参数:
  25. run_context (RunContext):包含模型的一些基本信息。
  26. end(run_context)
  27. 网络执行后被调用一次。
  28. 参数:
  29. run_context (RunContext):包含模型的一些基本信息。
  30. epoch_begin(run_context)
  31. 在每个epoch开始之前被调用。
  32. 参数:
  33. run_context (RunContext):包含模型的一些基本信息。
  34. epoch_end(run_context)
  35. 在每个epoch结束后被调用。
  36. 参数:
  37. run_context (RunContext):包含模型的一些基本信息。
  38. step_begin(run_context)
  39. 在每个step开始之前被调用。
  40. 参数:
  41. run_context (RunContext):包含模型的一些基本信息。
  42. step_end(run_context)
  43. 在每个step完成后被调用。
  44. 参数:
  45. run_context (RunContext):包含模型的一些基本信息。