浏览代码

!4010 modify checkpoint config param check

Merge pull request !4010 from changzherui/mod_ckpt_param
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 年前
父节点
当前提交
768ff072ad
共有 1 个文件被更改,包括 5 次插入6 次删除
  1. +5
    -6
      mindspore/train/callback/_checkpoint.py

+ 5
- 6
mindspore/train/callback/_checkpoint.py 查看文件

@@ -104,10 +104,6 @@ class CheckpointConfig:
integrated_save=True,
async_save=False):

if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
raise ValueError("The input_param can't be all None or 0")

if save_checkpoint_steps is not None:
save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
if save_checkpoint_seconds is not None:
@@ -117,6 +113,10 @@ class CheckpointConfig:
if keep_checkpoint_per_n_minutes is not None:
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)

if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
raise ValueError("The input_param can't be all None or 0")

self._save_checkpoint_steps = save_checkpoint_steps
self._save_checkpoint_seconds = save_checkpoint_seconds
if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
@@ -173,7 +173,6 @@ class CheckpointConfig:
return checkpoint_policy



class ModelCheckpoint(Callback):
"""
The checkpoint callback class.
@@ -203,7 +202,7 @@ class ModelCheckpoint(Callback):
raise ValueError("Prefix {} for checkpoint file name invalid, "
"please check and correct it and then continue.".format(prefix))

if directory:
if directory is not None:
self._directory = _make_directory(directory)
else:
self._directory = _cur_dir


正在加载...
取消
保存