|
|
|
@@ -26,6 +26,7 @@ from mindspore._checkparam import Validator |
|
|
|
from mindspore.train._utils import _make_directory |
|
|
|
from mindspore.train.serialization import save_checkpoint, _save_graph |
|
|
|
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank |
|
|
|
from mindspore.parallel._cell_wrapper import destroy_allgather_cell |
|
|
|
from ._callback import Callback, set_cur_net |
|
|
|
from ...common.tensor import Tensor |
|
|
|
|
|
|
|
@@ -33,17 +34,6 @@ _cur_dir = os.getcwd() |
|
|
|
_save_dir = _cur_dir |
|
|
|
|
|
|
|
|
|
|
|
def _check_file_name_prefix(file_name_prefix): |
|
|
|
""" |
|
|
|
Check file name valid or not. |
|
|
|
|
|
|
|
File name can't include '/'. This file name naming convention only apply to Linux. |
|
|
|
""" |
|
|
|
if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0: |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def _chg_ckpt_file_name_if_same_exist(directory, prefix): |
|
|
|
"""Check if there is a file with the same name.""" |
|
|
|
files = os.listdir(directory) |
|
|
|
@@ -245,11 +235,8 @@ class ModelCheckpoint(Callback): |
|
|
|
self._last_time_for_keep = time.time() |
|
|
|
self._last_triggered_step = 0 |
|
|
|
|
|
|
|
if _check_file_name_prefix(prefix): |
|
|
|
self._prefix = prefix |
|
|
|
else: |
|
|
|
raise ValueError("Prefix {} for checkpoint file name invalid, " |
|
|
|
"please check and correct it and then continue.".format(prefix)) |
|
|
|
Validator.check_file_name_by_regular(prefix) |
|
|
|
self._prefix = prefix |
|
|
|
|
|
|
|
if directory is not None: |
|
|
|
self._directory = _make_directory(directory) |
|
|
|
@@ -310,7 +297,6 @@ class ModelCheckpoint(Callback): |
|
|
|
if thread.getName() == "asyn_save_ckpt": |
|
|
|
thread.join() |
|
|
|
|
|
|
|
from mindspore.parallel._cell_wrapper import destroy_allgather_cell |
|
|
|
destroy_allgather_cell() |
|
|
|
|
|
|
|
def _check_save_ckpt(self, cb_params, force_to_save): |
|
|
|
|