|
|
|
@@ -228,6 +228,8 @@ class ModelCheckpoint(Callback): |
|
|
|
Args: |
|
|
|
run_context (RunContext): Context of the train running. |
|
|
|
""" |
|
|
|
if _is_role_pserver(): |
|
|
|
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix |
|
|
|
cb_params = run_context.original_args() |
|
|
|
# save graph (only once) |
|
|
|
if not self._graph_saved: |
|
|
|
@@ -281,8 +283,6 @@ class ModelCheckpoint(Callback): |
|
|
|
if save_ckpt: |
|
|
|
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ |
|
|
|
+ str(step_num_in_epoch) + ".ckpt" |
|
|
|
if _is_role_pserver(): |
|
|
|
cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file |
|
|
|
# update checkpoint file list. |
|
|
|
self._manager.update_ckpoint_filelist(self._directory, self._prefix) |
|
|
|
# keep checkpoint files number equal max number. |
|
|
|
|