Browse Source

!3907 modify ckpt func check parameter

Merge pull request !3907 from changzherui/mod_ckpt_func_param
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
3dcea81721
2 changed files with 5 additions and 5 deletions
  1. +4
    -4
      mindspore/train/callback/_checkpoint.py
  2. +1
    -1
      mindspore/train/serialization.py

+ 4
- 4
mindspore/train/callback/_checkpoint.py View File

@@ -108,13 +108,13 @@ class CheckpointConfig:
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
raise ValueError("The input_param can't be all None or 0")

if save_checkpoint_steps:
if save_checkpoint_steps is not None:
save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
if save_checkpoint_seconds:
if save_checkpoint_seconds is not None:
save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
if keep_checkpoint_max:
if keep_checkpoint_max is not None:
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
if keep_checkpoint_per_n_minutes:
if keep_checkpoint_per_n_minutes is not None:
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)

self._save_checkpoint_steps = save_checkpoint_steps


+ 1
- 1
mindspore/train/serialization.py View File

@@ -258,7 +258,7 @@ def load_checkpoint(ckpt_file_name, net=None):
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__())

if net:
if net is not None:
load_param_into_net(net, parameter_dict)

return parameter_dict


Loading…
Cancel
Save