Browse Source

add type check and error description

tags/v1.0.0
liuyang_655 5 years ago
parent
commit
f4c32bc93e
1 changed files with 7 additions and 3 deletions
  1. +7
    -3
      mindspore/train/serialization.py

+ 7
- 3
mindspore/train/serialization.py View File

@@ -148,18 +148,22 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
Args:
save_obj (nn.Cell or list): The cell object or data list(each element is a dictionary, like
[{"name": param_name, "data": param_data},...], the type of param_name would
be string, and the type of param_data would be parameter, tensor or numpy).
be string, and the type of param_data would be parameter or tensor).
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False

Raises:
TypeError: If the parameter save_obj is not nn.Cell or list type.
RuntimeError: Failed to save the Checkpoint file.
TypeError: If the parameter save_obj is not nn.Cell or list type.And if the parameter integrated_save and
async_save are not bool type.
"""

if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
if not isinstance(integrated_save, bool):
raise TypeError("The parameter integrated_save should be bool, but got {}".format(type(integrated_save)))
if not isinstance(async_save, bool):
raise TypeError("The parameter async_save should be bool, but got {}".format(type(async_save)))

logger.info("Execute save checkpoint process.")



Loading…
Cancel
Save