Browse Source

modify check file name bug

pull/15553/head
changzherui 4 years ago
parent
commit
f63a13b2e9
3 changed files with 10 additions and 44 deletions
  1. +5
    -0
      mindspore/_checkparam.py
  2. +3
    -17
      mindspore/train/callback/_checkpoint.py
  3. +2
    -27
      tests/ut/python/utils/test_callback.py

+ 5
- 0
mindspore/_checkparam.py View File

@@ -443,12 +443,17 @@ class Validator:
@staticmethod @staticmethod
def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
"""Check whether file name is legitimate.""" """Check whether file name is legitimate."""
if not isinstance(target, str):
raise ValueError("Args file_name {} must be string, please check it".format(target))
if target.endswith("\\") or target.endswith("/"):
raise ValueError("File name cannot be a directory path.")
if reg is None: if reg is None:
reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$" reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
if re.match(reg, target, flag) is None: if re.match(reg, target, flag) is None:
prim_name = f'in `{prim_name}`' if prim_name else "" prim_name = f'in `{prim_name}`' if prim_name else ""
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
target, prim_name, reg, flag)) target, prim_name, reg, flag))

return True return True


@staticmethod @staticmethod


+ 3
- 17
mindspore/train/callback/_checkpoint.py View File

@@ -26,6 +26,7 @@ from mindspore._checkparam import Validator
from mindspore.train._utils import _make_directory from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph from mindspore.train.serialization import save_checkpoint, _save_graph
from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank from mindspore.parallel._ps_context import _is_role_pserver, _get_ps_mode_rank
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
from ._callback import Callback, set_cur_net from ._callback import Callback, set_cur_net
from ...common.tensor import Tensor from ...common.tensor import Tensor


@@ -33,17 +34,6 @@ _cur_dir = os.getcwd()
_save_dir = _cur_dir _save_dir = _cur_dir




def _check_file_name_prefix(file_name_prefix):
"""
Check file name valid or not.

File name can't include '/'. This file name naming convention only apply to Linux.
"""
if not isinstance(file_name_prefix, str) or file_name_prefix.find('/') >= 0:
return False
return True


def _chg_ckpt_file_name_if_same_exist(directory, prefix): def _chg_ckpt_file_name_if_same_exist(directory, prefix):
"""Check if there is a file with the same name.""" """Check if there is a file with the same name."""
files = os.listdir(directory) files = os.listdir(directory)
@@ -245,11 +235,8 @@ class ModelCheckpoint(Callback):
self._last_time_for_keep = time.time() self._last_time_for_keep = time.time()
self._last_triggered_step = 0 self._last_triggered_step = 0


if _check_file_name_prefix(prefix):
self._prefix = prefix
else:
raise ValueError("Prefix {} for checkpoint file name invalid, "
"please check and correct it and then continue.".format(prefix))
Validator.check_file_name_by_regular(prefix)
self._prefix = prefix


if directory is not None: if directory is not None:
self._directory = _make_directory(directory) self._directory = _make_directory(directory)
@@ -310,7 +297,6 @@ class ModelCheckpoint(Callback):
if thread.getName() == "asyn_save_ckpt": if thread.getName() == "asyn_save_ckpt":
thread.join() thread.join()


from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell() destroy_allgather_cell()


def _check_save_ckpt(self, cb_params, force_to_save): def _check_save_ckpt(self, cb_params, force_to_save):


+ 2
- 27
tests/ut/python/utils/test_callback.py View File

@@ -28,7 +28,8 @@ from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum from mindspore.nn.optim import Momentum
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \
_CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op _CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist
from mindspore.train.callback._checkpoint import _chg_ckpt_file_name_if_same_exist



class Net(nn.Cell): class Net(nn.Cell):
"""Net definition.""" """Net definition."""
@@ -150,32 +151,6 @@ def test_loss_monitor_normal_mode():
loss_cb.end(run_context) loss_cb.end(run_context)




def test_check_file_name_not_str():
"""Test check file name not str."""
ret = _check_file_name_prefix(1)
assert not ret


def test_check_file_name_back_err():
"""Test check file name back err."""
ret = _check_file_name_prefix('abc.')
assert ret


def test_check_file_name_one_alpha():
"""Test check file name one alpha."""
ret = _check_file_name_prefix('a')
assert ret
ret = _check_file_name_prefix('_')
assert ret


def test_check_file_name_err():
"""Test check file name err."""
ret = _check_file_name_prefix('_123')
assert ret


def test_chg_ckpt_file_name_if_same_exist(): def test_chg_ckpt_file_name_if_same_exist():
"""Test chg ckpt file name if same exist.""" """Test chg ckpt file name if same exist."""
_chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt") _chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt")


Loading…
Cancel
Save