You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.train.callback.ModelCheckpoint.rst 1.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. mindspore.train.callback.ModelCheckpoint
  2. ==========================================
  3. .. py:class:: mindspore.train.callback.ModelCheckpoint(prefix='CKP', directory=None, config=None)
  4. checkpoint的回调函数。
  5. 在训练过程中调用该方法可以保存训练后的网络参数。
  6. .. note::
  7. 在分布式训练场景下,请为每个训练进程指定不同的目录来保存checkpoint文件。否则,可能会训练失败。
  8. **参数:**
  9. - **prefix** (str) - checkpoint文件的前缀名称。默认值:CKP。
  10. - **directory** (str) - 保存checkpoint文件的文件夹路径。默认情况下,文件保存在当前目录下。默认值:None。
  11. - **config** (CheckpointConfig) - checkpoint策略配置。默认值:None。
  12. **异常:**
  13. - **ValueError** - 如果前缀无效。
  14. - **TypeError** - config不是CheckpointConfig类型。
  15. .. py:method:: end(run_context)
  16. 在训练结束后,会保存最后一个step的checkpoint。
  17. **参数:**
  18. **run_context** (RunContext) - 包含模型的一些基本信息。
  19. .. py:method:: latest_ckpt_file_name
  20. :property:
  21. 返回最新的checkpoint路径和文件名。
  22. .. py:method:: step_end(run_context)
  23. 在step结束时保存checkpoint。
  24. **参数:**
  25. **run_context** (RunContext) - 包含模型的一些基本信息。