|
|
|
@@ -296,9 +296,6 @@ class ModelCheckpoint(Callback): |
|
|
|
cb_params = run_context.original_args() |
|
|
|
_to_save_last_ckpt = True |
|
|
|
|
|
|
|
# if param is cache enable, flush data from cache to host before epoch end |
|
|
|
self._flush_from_cache(cb_params) |
|
|
|
|
|
|
|
self._save_ckpt(cb_params, _to_save_last_ckpt) |
|
|
|
|
|
|
|
thread_list = threading.enumerate() |
|
|
|
@@ -328,6 +325,9 @@ class ModelCheckpoint(Callback): |
|
|
|
if cb_params.cur_step_num == self._last_triggered_step: |
|
|
|
return |
|
|
|
|
|
|
|
# if param is cache enable, flush data from cache to host before save_ckpt |
|
|
|
self._flush_from_cache(cb_params) |
|
|
|
|
|
|
|
save_ckpt = self._check_save_ckpt(cb_params, force_to_save) |
|
|
|
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) |
|
|
|
|
|
|
|
|