You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.save_checkpoint.rst 1.8 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435
  1. mindspore.save_checkpoint
  2. =========================
  3. .. py:class:: mindspore.save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM")
  4. 将网络权重保存到checkpoint文件中。
  5. **参数:**
  6. - **save_obj** (Union[Cell, list]) – Cell对象或者数据列表(列表的每个元素为字典类型,比如[{"name": param_name, “data”: param_data},…],*param_name* 的类型必须是str,*param_data* 的类型必须是Parameter或者Tensor)。
  7. - **ckpt_file_name** (str) – checkpoint文件名称。如果文件已存在,将会覆盖原有文件。
  8. - **integrated_save** (bool) – 在并行场景下是否合并保存拆分的Tensor。默认值:True。
  9. - **async_save** (bool) – 是否异步执行保存checkpoint文件。默认值:False。
  10. - **append_dict** (dict) – 需要保存的其他信息。dict的键必须为str类型,dict的值类型必须是float或者bool类型。默认值:None。
  11. - **enc_key** (Union[None, bytes]) – 用于加密的字节类型密钥。如果值为None,那么不需要加密。默认值:None。
  12. - **enc_mode** (str) – 该参数在 *enc_key* 不为None时有效,指定加密模式,目前仅支持"AES-GCM"和"AES-CBC"。 默认值:“AES-GCM”。
  13. **异常:**
  14. **TypeError** – 如果参数 *save_obj* 类型不为nn.Cell或者list,且如果参数 *integrated_save* 及 *async_save* 非bool类型。
  15. **样例:**
  16. .. code-block::
  17. >>> from mindspore import save_checkpoint
  18. >>>
  19. >>> net = Net()
  20. >>> save_checkpoint(net, "lenet.ckpt")