|
|
|
@@ -280,7 +280,10 @@ class ModelCheckpoint(Callback): |
|
|
|
os.remove(graph_file_name) |
|
|
|
_save_graph(cb_params.train_network, graph_file_name) |
|
|
|
self._graph_saved = True |
|
|
|
|
|
|
|
thread_list = threading.enumerate() |
|
|
|
for thread in thread_list: |
|
|
|
if thread.getName() == "asyn_save_ckpt": |
|
|
|
thread.join() |
|
|
|
self._save_ckpt(cb_params) |
|
|
|
|
|
|
|
def end(self, run_context): |
|
|
|
@@ -295,10 +298,9 @@ class ModelCheckpoint(Callback): |
|
|
|
self._save_ckpt(cb_params, _to_save_last_ckpt) |
|
|
|
|
|
|
|
thread_list = threading.enumerate() |
|
|
|
if len(thread_list) > 1: |
|
|
|
for thread in thread_list: |
|
|
|
if thread.getName() == "asyn_save_ckpt": |
|
|
|
thread.join() |
|
|
|
for thread in thread_list: |
|
|
|
if thread.getName() == "asyn_save_ckpt": |
|
|
|
thread.join() |
|
|
|
|
|
|
|
from mindspore.parallel._cell_wrapper import destroy_allgather_cell |
|
|
|
destroy_allgather_cell() |
|
|
|
|