|
|
|
@@ -18,6 +18,7 @@ import os |
|
|
|
import stat |
|
|
|
import time |
|
|
|
|
|
|
|
import threading |
|
|
|
import mindspore.context as context |
|
|
|
from mindspore import log as logger |
|
|
|
from mindspore._checkparam import check_bool, check_int_non_negative |
|
|
|
@@ -245,6 +246,12 @@ class ModelCheckpoint(Callback): |
|
|
|
_to_save_last_ckpt = True |
|
|
|
self._save_ckpt(cb_params, _to_save_last_ckpt) |
|
|
|
|
|
|
|
thread_list = threading.enumerate() |
|
|
|
if len(thread_list) > 1: |
|
|
|
for thread in thread_list: |
|
|
|
if thread.getName() == "asyn_save_ckpt": |
|
|
|
thread.join() |
|
|
|
|
|
|
|
from mindspore.parallel._cell_wrapper import destroy_allgather_cell |
|
|
|
destroy_allgather_cell() |
|
|
|
|
|
|
|
|