You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.train.callback.CheckpointConfig.rst 4.6 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. .. py:class:: mindspore.train.callback.CheckpointConfig(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM', exception_save=False)
  2. 保存checkpoint时的配置策略。
  3. .. note::
  4. 在训练过程中,如果数据集是通过数据通道传输的,建议将 `save_checkpoint_steps` 设为循环下沉step数量的整数倍数,否则,保存checkpoint的时机可能会有偏差。建议同时只设置一种触发保存checkpoint策略和一种保留checkpoint文件总数策略。如果同时设置了 `save_checkpoint_steps` 和 `save_checkpoint_seconds` ,则 `save_checkpoint_seconds` 无效。如果同时设置了 `keep_checkpoint_max` 和 `keep_checkpoint_per_n_minutes` ,则 `keep_checkpoint_per_n_minutes` 无效。
  5. **参数:**
  6. - **save_checkpoint_steps** (int) - 每隔多少个step保存一次checkpoint。默认值:1。
  7. - **save_checkpoint_seconds** (int) - 每隔多少秒保存一次checkpoint。不能同时与 `save_checkpoint_steps` 一起使用。默认值:0。
  8. - **keep_checkpoint_max** (int) - 最多保存多少个checkpoint文件。默认值:5。
  9. - **keep_checkpoint_per_n_minutes** (int) - 每隔多少分钟保存一个checkpoint文件。不能同时与 `keep_checkpoint_max` 一起使用。默认值:0。
  10. - **integrated_save** (bool) - 在自动并行场景下,是否合并保存拆分后的Tensor。合并保存功能仅支持在自动并行场景中使用,在手动并行场景中不支持。默认值:True。
  11. - **async_save** (bool) - 是否异步执行保存checkpoint文件。默认值:False。
  12. - **saved_network** (Cell) - 保存在checkpoint文件中的网络。如果 `saved_network` 没有被训练,则保存 `saved_network` 的初始值。默认值:None。
  13. - **append_info** (list) - 保存在checkpoint文件中的信息。支持"epoch_num"、"step_num"和dict类型。dict的key必须是str,dict的value必须是int、float或bool中的一个。默认值:None。
  14. - **enc_key** (Union[None, bytes]) - 用于加密的字节类型key。如果值为None,则不需要加密。默认值:None。
  15. - **enc_mode** (str) - 仅当 `enc_key` 不设为None时,该参数有效。指定了加密模式,目前支持AES-GCM和AES-CBC。默认值:AES-GCM。
  16. - **exception_save** (bool) - 当有异常发生时,是否保存当前checkpoint文件。默认值:False。
  17. **异常:**
  18. - **ValueError** - 输入参数的类型不正确。
  19. .. py:method:: append_dict
  20. :property:
  21. 获取需要额外保存到checkpoint中的字典的值。
  22. **返回:**
  23. Dict: 字典中的值。
  24. .. py:method:: async_save
  25. :property:
  26. 获取是否异步保存checkpoint。
  27. **返回:**
  28. Bool: 是否异步保存checkpoint。
  29. .. py:method:: enc_key
  30. :property:
  31. 获取加密的key值。
  32. **返回:**
  33. (None, bytes): 加密的key值。
  34. .. py:method:: enc_mode
  35. :property:
  36. 获取加密模式。
  37. **返回:**
  38. str: 加密模式。
  39. .. py:method:: get_checkpoint_policy()
  40. 获取checkpoint的保存策略。
  41. **返回:**
  42. Dict: checkpoint的保存策略。
  43. .. py:method:: integrated_save
  44. :property:
  45. 获取是否合并保存拆分后的Tensor。
  46. **返回:**
  47. Bool: 获取是否合并保存拆分后的Tensor。
  48. .. py:method:: keep_checkpoint_max
  49. :property:
  50. 获取最多保存checkpoint文件的数量。
  51. **返回:**
  52. Int: 最多保存checkpoint文件的数量。
  53. .. py:method:: keep_checkpoint_per_n_minutes
  54. :property:
  55. 获取每隔多少分钟保存一个checkpoint文件。
  56. **返回:**
  57. Int: 每隔多少分钟保存一个checkpoint文件。
  58. .. py:method:: save_checkpoint_seconds
  59. :property:
  60. 获取每隔多少秒保存一次checkpoint文件。
  61. **返回:**
  62. Int: 每隔多少秒保存一次checkpoint文件。
  63. .. py:method:: save_checkpoint_steps
  64. :property:
  65. 获取每隔多少个step保存一次checkpoint文件。
  66. **返回:**
  67. Int: 每隔多少个step保存一次checkpoint文件。
  68. .. py:method:: saved_network
  69. :property:
  70. 获取需要保存的网络。
  71. **返回:**
  72. Cell: 需要保存的网络。