|
|
|
@@ -244,6 +244,10 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, |
|
|
|
if not isinstance(ckpt_file_name, str): |
|
|
|
raise TypeError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be " |
|
|
|
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name))) |
|
|
|
ckpt_file_name = os.path.realpath(ckpt_file_name) |
|
|
|
if os.path.isdir(ckpt_file_name): |
|
|
|
raise IsADirectoryError("The argument `ckpt_file_name`: {} is a directory, " |
|
|
|
"it should be a file name.".format(ckpt_file_name)) |
|
|
|
if not ckpt_file_name.endswith('.ckpt'): |
|
|
|
ckpt_file_name += ".ckpt" |
|
|
|
integrated_save = Validator.check_bool(integrated_save) |
|
|
|
@@ -298,7 +302,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, |
|
|
|
data = param["data"].asnumpy().reshape(-1) |
|
|
|
data_list[key].append(data) |
|
|
|
|
|
|
|
ckpt_file_name = os.path.realpath(ckpt_file_name) |
|
|
|
if async_save: |
|
|
|
data_copy = copy.deepcopy(data_list) |
|
|
|
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_copy, enc_key, enc_mode), name="asyn_save_ckpt") |
|
|
|
|