You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.train.callback.Callback.rst 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. mindspore.train.callback.Callback
  2. ===================================
  3. .. py:class:: mindspore.train.callback.Callback
  4. 用于构建回调函数的基类。回调函数是一个上下文管理器,在运行模型时被调用。
  5. 可以使用此机制进行初始化和释放资源等操作。
  6. 回调函数可以在step或epoch中的执行一些操作。
  7. 它保存模型相关信息。例如 `network` 、 `train_network` 、 `epoch_num` 、 `batch_num` 、 `loss_fn` 、 `optimizer` 、 `parallel_mode` 、 `device_number` 、 `list_callback` 、 `cur_epoch_num` 、 `cur_step_num` 、 `dataset_sink_mode` 、 `net_outputs` 等。
  8. **样例:**
  9. >>> from mindspore import Model, nn
  10. >>> from mindspore.train.callback import Callback
  11. >>> class Print_info(Callback):
  12. ... def step_end(self, run_context):
  13. ... cb_params = run_context.original_args()
  14. ... print("step_num: ", cb_params.cur_step_num)
  15. >>>
  16. >>> print_cb = Print_info()
  17. >>> dataset = create_custom_dataset()
  18. >>> net = Net()
  19. >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  20. >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
  21. >>> model = Model(net, loss_fn=loss, optimizer=optim)
  22. >>> model.train(1, dataset, callbacks=print_cb)
  23. step_num:1
  24. .. py:method:: begin(run_context)
  25. 在网络执行之前被调用一次。
  26. **参数:**
  27. **run_context** (RunContext) - 包含模型的一些基本信息。
  28. .. py:method:: end(run_context)
  29. 网络执行后被调用一次。
  30. **参数:**
  31. **run_context** (RunContext) - 包含模型的一些基本信息。
  32. .. py:method:: epoch_begin(run_context)
  33. 在每个epoch开始之前被调用。
  34. **参数:**
  35. **run_context** (RunContext) - 包含模型的一些基本信息。
  36. .. py:method:: epoch_end(run_context)
  37. 在每个epoch结束后被调用。
  38. **参数:**
  39. **run_context** (RunContext) - 包含模型的一些基本信息。
  40. .. py:method:: step_begin(run_context)
  41. 在每个step开始之前被调用。
  42. **参数:**
  43. **run_context** (RunContext) - 包含模型的一些基本信息。
  44. .. py:method:: step_end(run_context)
  45. 在每个step完成后被调用。
  46. **参数:**
  47. **run_context** (RunContext) - 包含模型的一些基本信息。