
!2878 Asynchronous saving checkpoint

Merge pull request !2878 from mindspore_ding/checkpoint_mindspore_new
tags/v0.6.0-beta
mindspore-ci-bot committed 5 years ago
commit 4936fe487f
2 changed files with 58 additions and 29 deletions
  1. mindspore/train/callback/_checkpoint.py  (+11, -7)
  2. mindspore/train/serialization.py  (+47, -22)

mindspore/train/callback/_checkpoint.py  (+11, -7)

@@ -15,7 +15,6 @@
 """Checkpoint related classes and functions."""

 import os
-import shutil
 import stat
 import time

@@ -86,6 +85,7 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time. Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False


Raises: Raises:
ValueError: If the input_param is None or 0. ValueError: If the input_param is None or 0.
@@ -100,7 +100,8 @@ class CheckpointConfig:
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
-                 integrated_save=True):
+                 integrated_save=True,
+                 async_save=False):

         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -129,6 +130,7 @@ class CheckpointConfig:
             self._keep_checkpoint_max = 1

         self._integrated_save = check_bool(integrated_save)
+        self._async_save = check_bool(async_save)

     @property
     def save_checkpoint_steps(self):
@@ -155,6 +157,11 @@ class CheckpointConfig:
         """Get the value of _integrated_save."""
         return self._integrated_save

+    @property
+    def async_save(self):
+        """Get the value of _async_save."""
+        return self._async_save
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
@@ -282,8 +289,6 @@ class ModelCheckpoint(Callback):
             global _save_dir
             _save_dir = self._directory
             cur_file = os.path.join(self._directory, cur_ckpoint_file)
-            tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
-            gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
             self._last_time_for_keep = time.time()
             self._last_triggered_step = cb_params.cur_step_num

@@ -291,10 +296,9 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()

-            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
+            _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
+                                  self._config.async_save)

-            if os.path.exists(gen_file):
-                shutil.move(gen_file, cur_file)
             self._latest_ckpt_file_name = cur_file

     @property
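
Note that the previous two-step dance (write to a PID-named temporary 'parameters.ckpt', then shutil.move it into place) is gone: _exec_save_checkpoint now writes the final file directly, which is also why the shutil import could be dropped. Callers opt into the new behavior through the config object. A minimal sketch of the intended use; the prefix, directory, and step values are illustrative placeholders, not part of the patch:

    from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

    # async_save is the new knob; everything else is unchanged.
    config = CheckpointConfig(save_checkpoint_steps=100,
                              keep_checkpoint_max=5,
                              async_save=True)  # write checkpoints in a background thread
    ckpt_cb = ModelCheckpoint(prefix="lenet", directory="./ckpt", config=config)
    # `ckpt_cb` would then be passed to Model.train(..., callbacks=[ckpt_cb]).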


mindspore/train/serialization.py  (+47, -22)

@@ -15,6 +15,7 @@
 """Model and parameters serialization."""
 import os
 import stat
+from threading import Thread, Lock
 import numpy as np

 import mindspore.nn as nn
@@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
                      "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
                      "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}

+_ckpt_mutex = Lock()


 def _special_process_par(par, new_par):
     """
@@ -101,7 +103,29 @@ def _update_param(param, new_param):
     param.set_parameter_data(type(param.data)(new_param.data))


-def save_checkpoint(parameter_list, ckpt_file_name):
+def _exec_save(ckpt_file_name, data_list):
+    """Execute the process of saving checkpoint into file."""
+    checkpoint_list = Checkpoint()
+
+    try:
+        with _ckpt_mutex:
+            for name, value in data_list.items():
+                param_value = checkpoint_list.value.add()
+                param_value.tag = name
+                param_tensor = param_value.tensor
+                param_tensor.dims.extend(value[0])
+                param_tensor.tensor_type = value[1]
+                param_tensor.tensor_content = value[2].tostring()
+
+            with open(ckpt_file_name, "wb") as f:
+                f.write(checkpoint_list.SerializeToString())
+            os.chmod(ckpt_file_name, stat.S_IRUSR)
+
+    except BaseException as e:
+        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
+        raise RuntimeError(e.__str__())
+
+
+def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
     """
     Saves checkpoint info to a specified file.


@@ -109,37 +133,37 @@ def save_checkpoint(parameter_list, ckpt_file_name):
         parameter_list (list): Parameters list, each element is a dict
                                like {"name":xx, "type":xx, "shape":xx, "data":xx}.
         ckpt_file_name (str): Checkpoint file name.
+        async_save (bool): Whether to save the checkpoint file asynchronously. Default: False.

     Raises:
         RuntimeError: Failed to save the Checkpoint file.
     """
     logger.info("Execute save checkpoint process.")
-    checkpoint_list = Checkpoint()

-    try:
-        for param in parameter_list:
-            param_value = checkpoint_list.value.add()
-            param_value.tag = param["name"]
-            param_tensor = param_value.tensor
+    data_list = {}
+    with _ckpt_mutex:
+        for param in parameter_list:
+            key = param["name"]
+            data_list[key] = []
             if isinstance(param["data"], Parameter):
                 param["data"].init_data()
-            param_data = param["data"].asnumpy().reshape(-1)
-            param_tensor.tensor_content = param_data.tostring()
-            param_tensor.tensor_type = str(param["data"].dtype)
-
+            dims = []
             if param['data'].shape == ():
-                param_tensor.dims.append(0)
+                dims.append(0)
             else:
                 for dim in param['data'].shape:
-                    param_tensor.dims.append(dim)
-
-        with open(ckpt_file_name, "wb") as f:
-            f.write(checkpoint_list.SerializeToString())
-        os.chmod(ckpt_file_name, stat.S_IRUSR)
-
-    except BaseException as e:
-        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
-        raise RuntimeError(e.__str__())
+                    dims.append(dim)
+            data_list[key].append(dims)
+            tensor_type = str(param["data"].dtype)
+            data_list[key].append(tensor_type)
+            data = param["data"].asnumpy().reshape(-1)
+            data_list[key].append(data)
+
+    if async_save:
+        thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
+        thr.start()
+    else:
+        _exec_save(ckpt_file_name, data_list)
     logger.info("Save checkpoint process finish.")
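
The rewrite splits saving into two phases. save_checkpoint now only copies each parameter's dims, dtype string, and flattened numpy data into a plain dict while holding the mutex; _exec_save then does the protobuf packing and file I/O, inline or on a threading.Thread when async_save is set. Snapshotting on the caller's thread is what makes the background write safe: training can keep mutating parameters after the call returns without racing the writer. Below is a self-contained sketch of the same snapshot-then-write pattern; every name in it is illustrative, none comes from the patch:

    import threading

    _lock = threading.Lock()

    def _write_file(path, snapshot):
        # Phase 2: slow serialization and I/O, safe on a worker thread
        # because it only touches the private snapshot.
        with _lock:
            with open(path, "wb") as f:
                for name, payload in snapshot.items():
                    f.write(name.encode() + b"=" + payload + b"\n")

    def save(data, path, async_save=False):
        # Phase 1: cheap copy on the caller's thread, so later updates
        # to `data` cannot race the background writer.
        with _lock:
            snapshot = {name: bytes(payload) for name, payload in data.items()}
        if async_save:
            threading.Thread(target=_write_file, args=(path, snapshot)).start()
        else:
            _write_file(path, snapshot)

    save({"w": b"\x00\x01", "b": b"\x02"}, "demo.ckpt", async_save=True)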




@@ -305,7 +329,7 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IRUSR)


-def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
+def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
     """
     Saves checkpoint for 'ms' backend.


@@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True,
         train_network (Network): The train network for training.
         ckpt_file_name (str): The name of checkpoint file.
         integrated_save (bool): Whether to perform integrated save in automatic model parallel scene.
+        async_save (bool): Whether to save the checkpoint file asynchronously. Default: False.
     """

     param_dict = {}
@@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True,
             each_param["data"] = param_data
             param_list.append(each_param)

-    save_checkpoint(param_list, ckpt_file_name)
+    save_checkpoint(param_list, ckpt_file_name, async_save)


 def _get_merged_param_data(net, param_name, param_data):
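
For completeness, a hedged sketch of calling the updated public entry point directly; the parameter name and values are made-up placeholders:

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.serialization import save_checkpoint

    param_list = [{"name": "conv1.weight",
                   "data": Tensor(np.ones((3, 3), np.float32))}]
    # With async_save=True the call returns once the worker thread starts;
    # the .ckpt file appears on disk shortly after.
    save_checkpoint(param_list, "net.ckpt", async_save=True)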

