Browse Source

fix checkpoint cache and process exception

tags/v1.2.0-rc1
fangzehua 4 years ago
parent
commit
da85c804bc
2 changed files with 4 additions and 3 deletions
  1. +1
    -0
      mindspore/ccsrc/backend/session/executor.cc
  2. +3
    -3
      mindspore/train/callback/_checkpoint.py

+ 1
- 0
mindspore/ccsrc/backend/session/executor.cc View File

@@ -349,6 +349,7 @@ void Executor::WaitTaskGraphAvailable(const SessionPtr &session, const std::shar
mindspore::ScopedLongRunning long_running; mindspore::ScopedLongRunning long_running;
for (auto &tensor : task->input_tensors_) { for (auto &tensor : task->input_tensors_) {
if (tensor->NeedWait() && !tensor->IsGraphOutput()) { if (tensor->NeedWait() && !tensor->IsGraphOutput()) {
MsException::Instance().CheckException();
tensor->Wait(); tensor->Wait();
} }
} }


+ 3
- 3
mindspore/train/callback/_checkpoint.py View File

@@ -296,9 +296,6 @@ class ModelCheckpoint(Callback):
cb_params = run_context.original_args() cb_params = run_context.original_args()
_to_save_last_ckpt = True _to_save_last_ckpt = True


# if param is cache enable, flush data from cache to host before epoch end
self._flush_from_cache(cb_params)

self._save_ckpt(cb_params, _to_save_last_ckpt) self._save_ckpt(cb_params, _to_save_last_ckpt)


thread_list = threading.enumerate() thread_list = threading.enumerate()
@@ -328,6 +325,9 @@ class ModelCheckpoint(Callback):
if cb_params.cur_step_num == self._last_triggered_step: if cb_params.cur_step_num == self._last_triggered_step:
return return


# if param is cache enable, flush data from cache to host before save_ckpt
self._flush_from_cache(cb_params)

save_ckpt = self._check_save_ckpt(cb_params, force_to_save) save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)




Loading…
Cancel
Save