| @@ -148,18 +148,22 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||||
| Args: | Args: | ||||
| save_obj (nn.Cell or list): The cell object or data list(each element is a dictionary, like | save_obj (nn.Cell or list): The cell object or data list(each element is a dictionary, like | ||||
| [{"name": param_name, "data": param_data},...], the type of param_name would | [{"name": param_name, "data": param_data},...], the type of param_name would | ||||
| be string, and the type of param_data would be parameter, tensor or numpy). | |||||
| be string, and the type of param_data would be parameter or tensor). | |||||
| ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. | ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. | ||||
| integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True | integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True | ||||
| async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False | async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False | ||||
| Raises: | Raises: | ||||
| TypeError: If the parameter save_obj is not nn.Cell or list type. | |||||
| RuntimeError: Failed to save the Checkpoint file. | |||||
| TypeError: If the parameter save_obj is not nn.Cell or list type.And if the parameter integrated_save and | |||||
| async_save are not bool type. | |||||
| """ | """ | ||||
| if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): | if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): | ||||
| raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) | raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) | ||||
| if not isinstance(integrated_save, bool): | |||||
| raise TypeError("The parameter integrated_save should be bool, but got {}".format(type(integrated_save))) | |||||
| if not isinstance(async_save, bool): | |||||
| raise TypeError("The parameter async_save should be bool, but got {}".format(type(async_save))) | |||||
| logger.info("Execute save checkpoint process.") | logger.info("Execute save checkpoint process.") | ||||