You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.load_checkpoint.rst 1.7 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738
  1. mindspore.load_checkpoint
  2. ==========================
  3. .. py:class:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM")
  4. 加载checkpoint文件。
  5. **参数:**
  6. - **ckpt_file_name** (str) – checkpoint的文件名称。
  7. - **net** (Cell) – 加载checkpoint参数的网络。默认值:None。
  8. - **strict_load** (bool) – 是否将严格加载参数到网络中。如果是False, 它将根据相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行强制精度转换,比如将float32转换为float16。默认值:False。
  9. - **filter_prefix** (Union[str, list[str], tuple[str]]) – 以filter_prefix开头的参数将不会被加载。默认值:None。
  10. - **dec_key** (Union[None, bytes]) – 用于解密的字节类型密钥,如果值为None,则不需要解密。默认值:None。
  11. - **dec_mode** (str) – 该参数仅当dec_key不为None时有效。指定解密模式,目前支持“AES-GCM”和“AES-CBC”。默认值:“AES-GCM”。
  12. **返回:**
  13. 字典,key是参数名称,value是Parameter类型。
  14. **异常:**
  15. **ValueError** – checkpoint文件格式正确。
  16. **样例:**
  17. .. code-block::
  18. >>> from mindspore import load_checkpoint
  19. >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
  20. >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
  21. >>> print(param_dict["conv2.weight"])
  22. Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)