|
|
|
@@ -261,6 +261,7 @@ class ModelCheckpoint(Callback): |
|
|
|
self._manager = CheckpointManager() |
|
|
|
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) |
|
|
|
self._graph_saved = False |
|
|
|
self._need_flush_from_cache = True |
|
|
|
|
|
|
|
def step_end(self, run_context): |
|
|
|
""" |
|
|
|
@@ -326,7 +327,8 @@ class ModelCheckpoint(Callback): |
|
|
|
return |
|
|
|
|
|
|
|
# if param is cache enable, flush data from cache to host before save_ckpt |
|
|
|
self._flush_from_cache(cb_params) |
|
|
|
if self._need_flush_from_cache: |
|
|
|
self._flush_from_cache(cb_params) |
|
|
|
|
|
|
|
save_ckpt = self._check_save_ckpt(cb_params, force_to_save) |
|
|
|
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) |
|
|
|
@@ -365,10 +367,14 @@ class ModelCheckpoint(Callback): |
|
|
|
|
|
|
|
def _flush_from_cache(self, cb_params):
    """Flush every cache-enabled parameter of the training network.

    Walks the parameters returned by
    ``cb_params.train_network.get_parameters()`` and calls
    ``Tensor(param).flush_from_cache()`` on each one whose ``cache_enable``
    attribute is truthy (presumably moving cached data back to host --
    confirm against the Tensor API).

    If no parameter is cache-enabled, ``self._need_flush_from_cache`` is
    set to ``False``.

    Args:
        cb_params: Callback parameters object exposing ``train_network``.
    """
    found_cached_param = False
    for parameter in cb_params.train_network.get_parameters():
        if not parameter.cache_enable:
            continue
        found_cached_param = True
        Tensor(parameter).flush_from_cache()

    # Nothing was cache-enabled: record that on the instance flag.
    if not found_cached_param:
        self._need_flush_from_cache = False
|
|
|
|
|
|
|
@property |
|
|
|
def latest_ckpt_file_name(self): |
|
|
|
|