| @@ -443,12 +443,17 @@ class Validator: | |||||
| @staticmethod | @staticmethod | ||||
| def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): | def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): | ||||
| """Check whether file name is legitimate.""" | """Check whether file name is legitimate.""" | ||||
| if not isinstance(target, str): | |||||
| raise ValueError("Args file_name {} must be string, please check it".format(target)) | |||||
| if target.endswith("\\") or target.endswith("/"): | |||||
| raise ValueError("File name cannot be a directory path.") | |||||
| if reg is None: | if reg is None: | ||||
| reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$" | reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$" | ||||
| if re.match(reg, target, flag) is None: | if re.match(reg, target, flag) is None: | ||||
| prim_name = f'in `{prim_name}`' if prim_name else "" | prim_name = f'in `{prim_name}`' if prim_name else "" | ||||
| raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( | raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( | ||||
| target, prim_name, reg, flag)) | target, prim_name, reg, flag)) | ||||
| return True | return True | ||||
| @staticmethod | @staticmethod | ||||
| @@ -26,6 +26,7 @@ from mindspore._checkparam import Validator | |||||
| from mindspore.train._utils import _make_directory | from mindspore.train._utils import _make_directory | ||||
| from mindspore.train.serialization import save_checkpoint, _save_graph | from mindspore.train.serialization import save_checkpoint, _save_graph | ||||
| from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank | from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank | ||||
| from mindspore.parallel._cell_wrapper import destroy_allgather_cell | |||||
| from ._callback import Callback, set_cur_net | from ._callback import Callback, set_cur_net | ||||
| from ...common.tensor import Tensor | from ...common.tensor import Tensor | ||||
| @@ -33,17 +34,6 @@ _cur_dir = os.getcwd() | |||||
| _save_dir = _cur_dir | _save_dir = _cur_dir | ||||
| def _check_file_name_prefix(file_name_prefix): | |||||
| """ | |||||
| Check file name valid or not. | |||||
| File name can't include '/'. This file name naming convention only apply to Linux. | |||||
| """ | |||||
| if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0: | |||||
| return False | |||||
| return True | |||||
| def _chg_ckpt_file_name_if_same_exist(directory, prefix): | def _chg_ckpt_file_name_if_same_exist(directory, prefix): | ||||
| """Check if there is a file with the same name.""" | """Check if there is a file with the same name.""" | ||||
| files = os.listdir(directory) | files = os.listdir(directory) | ||||
| @@ -245,11 +235,8 @@ class ModelCheckpoint(Callback): | |||||
| self._last_time_for_keep = time.time() | self._last_time_for_keep = time.time() | ||||
| self._last_triggered_step = 0 | self._last_triggered_step = 0 | ||||
| if _check_file_name_prefix(prefix): | |||||
| self._prefix = prefix | |||||
| else: | |||||
| raise ValueError("Prefix {} for checkpoint file name invalid, " | |||||
| "please check and correct it and then continue.".format(prefix)) | |||||
| Validator.check_file_name_by_regular(prefix) | |||||
| self._prefix = prefix | |||||
| if directory is not None: | if directory is not None: | ||||
| self._directory = _make_directory(directory) | self._directory = _make_directory(directory) | ||||
| @@ -310,7 +297,6 @@ class ModelCheckpoint(Callback): | |||||
| if thread.getName() == "asyn_save_ckpt": | if thread.getName() == "asyn_save_ckpt": | ||||
| thread.join() | thread.join() | ||||
| from mindspore.parallel._cell_wrapper import destroy_allgather_cell | |||||
| destroy_allgather_cell() | destroy_allgather_cell() | ||||
| def _check_save_ckpt(self, cb_params, force_to_save): | def _check_save_ckpt(self, cb_params, force_to_save): | ||||
| @@ -28,7 +28,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell | |||||
| from mindspore.nn.optim import Momentum | from mindspore.nn.optim import Momentum | ||||
| from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ | from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ | ||||
| _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op | _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op | ||||
| from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist | |||||
| from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist | |||||
| class Net(nn.Cell): | class Net(nn.Cell): | ||||
| """Net definition.""" | """Net definition.""" | ||||
| @@ -150,32 +151,6 @@ def test_loss_monitor_normal_mode(): | |||||
| loss_cb.end(run_context) | loss_cb.end(run_context) | ||||
| def test_check_file_name_not_str(): | |||||
| """Test check file name not str.""" | |||||
| ret = _check_file_name_prefix(1) | |||||
| assert not ret | |||||
| def test_check_file_name_back_err(): | |||||
| """Test check file name back err.""" | |||||
| ret = _check_file_name_prefix('abc.') | |||||
| assert ret | |||||
| def test_check_file_name_one_alpha(): | |||||
| """Test check file name one alpha.""" | |||||
| ret = _check_file_name_prefix('a') | |||||
| assert ret | |||||
| ret = _check_file_name_prefix('_') | |||||
| assert ret | |||||
| def test_check_file_name_err(): | |||||
| """Test check file name err.""" | |||||
| ret = _check_file_name_prefix('_123') | |||||
| assert ret | |||||
| def test_chg_ckpt_file_name_if_same_exist(): | def test_chg_ckpt_file_name_if_same_exist(): | ||||
| """Test chg ckpt file name if same exist.""" | """Test chg ckpt file name if same exist.""" | ||||
| _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt") | _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt") | ||||