Browse Source

!10604 modify save checkpoint file

From: @changzherui
Reviewed-by: @zhunaipan,@kingxian
Signed-off-by: @kingxian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
8bd048cb0f
2 changed files with 2 additions and 4 deletions
  1. +0
    -4
      mindspore/train/_utils.py
  2. +2
    -0
      mindspore/train/callback/_checkpoint.py

+ 0
- 4
mindspore/train/_utils.py View File

@@ -76,20 +76,16 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf

def _make_directory(path: str):
"""Make directory."""
real_path = None
if path is None or not isinstance(path, str) or path.strip() == "":
logger.error("The path(%r) is invalid type.", path)
raise TypeError("Input path is invaild type")

# convert the relative paths
path = os.path.realpath(path)
logger.debug("The abs path is %r", path)

# check the path is exist and write permissions?
if os.path.exists(path):
real_path = path
else:
# All exceptions need to be caught because create directory maybe have some limit(permissions)
logger.debug("The directory(%s) doesn't exist, will create it", path)
try:
os.makedirs(path, exist_ok=True)


+ 2
- 0
mindspore/train/callback/_checkpoint.py View File

@@ -272,6 +272,7 @@ class ModelCheckpoint(Callback):
if _is_role_pserver():
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
cb_params = run_context.original_args()
_make_directory(self._directory)
# save graph (only once)
if not self._graph_saved:
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
@@ -279,6 +280,7 @@ class ModelCheckpoint(Callback):
os.remove(graph_file_name)
_save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True

self._save_ckpt(cb_params)

def end(self, run_context):


Loading…
Cancel
Save