| @@ -443,12 +443,17 @@ class Validator: | |||
| @staticmethod | |||
| def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): | |||
| """Check whether file name is legitimate.""" | |||
| if not isinstance(target, str): | |||
| raise ValueError("Args file_name {} must be string, please check it".format(target)) | |||
| if target.endswith("\\") or target.endswith("/"): | |||
| raise ValueError("File name cannot be a directory path.") | |||
| if reg is None: | |||
| reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$" | |||
| if re.match(reg, target, flag) is None: | |||
| prim_name = f'in `{prim_name}`' if prim_name else "" | |||
| raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( | |||
| target, prim_name, reg, flag)) | |||
| return True | |||
| @staticmethod | |||
| @@ -26,6 +26,7 @@ from mindspore._checkparam import Validator | |||
| from mindspore.train._utils import _make_directory | |||
| from mindspore.train.serialization import save_checkpoint, _save_graph | |||
| from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank | |||
| from mindspore.parallel._cell_wrapper import destroy_allgather_cell | |||
| from ._callback import Callback, set_cur_net | |||
| from ...common.tensor import Tensor | |||
| @@ -33,17 +34,6 @@ _cur_dir = os.getcwd() | |||
| _save_dir = _cur_dir | |||
| def _check_file_name_prefix(file_name_prefix): | |||
| """ | |||
| Check file name valid or not. | |||
| File name can't include '/'. This file name naming convention only apply to Linux. | |||
| """ | |||
| if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0: | |||
| return False | |||
| return True | |||
| def _chg_ckpt_file_name_if_same_exist(directory, prefix): | |||
| """Check if there is a file with the same name.""" | |||
| files = os.listdir(directory) | |||
| @@ -245,11 +235,8 @@ class ModelCheckpoint(Callback): | |||
| self._last_time_for_keep = time.time() | |||
| self._last_triggered_step = 0 | |||
| if _check_file_name_prefix(prefix): | |||
| self._prefix = prefix | |||
| else: | |||
| raise ValueError("Prefix {} for checkpoint file name invalid, " | |||
| "please check and correct it and then continue.".format(prefix)) | |||
| Validator.check_file_name_by_regular(prefix) | |||
| self._prefix = prefix | |||
| if directory is not None: | |||
| self._directory = _make_directory(directory) | |||
| @@ -310,7 +297,6 @@ class ModelCheckpoint(Callback): | |||
| if thread.getName() == "asyn_save_ckpt": | |||
| thread.join() | |||
| from mindspore.parallel._cell_wrapper import destroy_allgather_cell | |||
| destroy_allgather_cell() | |||
| def _check_save_ckpt(self, cb_params, force_to_save): | |||
| @@ -28,7 +28,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell | |||
| from mindspore.nn.optim import Momentum | |||
| from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ | |||
| _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op | |||
| from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist | |||
| from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist | |||
| class Net(nn.Cell): | |||
| """Net definition.""" | |||
| @@ -150,32 +151,6 @@ def test_loss_monitor_normal_mode(): | |||
| loss_cb.end(run_context) | |||
| def test_check_file_name_not_str(): | |||
| """Test check file name not str.""" | |||
| ret = _check_file_name_prefix(1) | |||
| assert not ret | |||
| def test_check_file_name_back_err(): | |||
| """Test check file name back err.""" | |||
| ret = _check_file_name_prefix('abc.') | |||
| assert ret | |||
| def test_check_file_name_one_alpha(): | |||
| """Test check file name one alpha.""" | |||
| ret = _check_file_name_prefix('a') | |||
| assert ret | |||
| ret = _check_file_name_prefix('_') | |||
| assert ret | |||
| def test_check_file_name_err(): | |||
| """Test check file name err.""" | |||
| ret = _check_file_name_prefix('_123') | |||
| assert ret | |||
| def test_chg_ckpt_file_name_if_same_exist(): | |||
| """Test chg ckpt file name if same exist.""" | |||
| _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt") | |||