Browse Source

modify asyn save checkpoint bug

tags/v0.6.0-beta
changzherui 5 years ago
parent
commit
99a2ab4b2e
2 changed files with 8 additions and 1 deletions
  1. +7
    -0
      mindspore/train/callback/_checkpoint.py
  2. +1
    -1
      mindspore/train/serialization.py

+ 7
- 0
mindspore/train/callback/_checkpoint.py View File

@@ -18,6 +18,7 @@ import os
import stat
import time

import threading
import mindspore.context as context
from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative
@@ -245,6 +246,12 @@ class ModelCheckpoint(Callback):
_to_save_last_ckpt = True
self._save_ckpt(cb_params, _to_save_last_ckpt)

thread_list = threading.enumerate()
if len(thread_list) > 1:
for thread in thread_list:
if thread.getName() == "asyn_save_ckpt":
thread.join()

from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell()



+ 1
- 1
mindspore/train/serialization.py View File

@@ -160,7 +160,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
data_list[key].append(data)

if async_save:
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list), name="asyn_save_ckpt")
thr.start()
else:
_exec_save(ckpt_file_name, data_list)


Loading…
Cancel
Save