Browse Source

!7401 custom ckpt save and load

Merge pull request !7401 from caozhou/custom_ckpt_save_and_load
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ba6023b87d
2 changed files with 75 additions and 13 deletions
  1. +44
    -8
      mindspore/train/callback/_checkpoint.py
  2. +31
    -5
      mindspore/train/serialization.py

+ 44
- 8
mindspore/train/callback/_checkpoint.py View File

@@ -21,6 +21,7 @@ import time
import threading import threading
import mindspore.context as context import mindspore.context as context
from mindspore import log as logger from mindspore import log as logger
from mindspore import nn
from mindspore._checkparam import Validator from mindspore._checkparam import Validator
from mindspore.train._utils import _make_directory from mindspore.train._utils import _make_directory
from mindspore.train.serialization import save_checkpoint, _save_graph from mindspore.train.serialization import save_checkpoint, _save_graph
@@ -88,13 +89,36 @@ class CheckpointConfig:
integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene. integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene.
Default: True. Integrated save function is only supported in automatic parallel scene, not supported Default: True. Integrated save function is only supported in automatic parallel scene, not supported
in manual parallel. in manual parallel.
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
saved_network (Cell): Network to be saved in checkpoint file. Default: None.


Raises: Raises:
ValueError: If the input_param is None or 0. ValueError: If the input_param is None or 0.


Examples: Examples:
>>> config = CheckpointConfig()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
>>> self.bn = nn.BatchNorm2d(64)
>>> self.relu = nn.ReLU()
>>> self.flatten = nn.Flatten()
>>> self.fc = nn.Dense(64*224*224, 12)
>>>
>>> def construct(self, x):
>>> x = self.conv(x)
>>> x = self.bn(x)
>>> x = self.relu(x)
>>> x = self.flatten(x)
>>> out = self.fc(x)
>>> return out
>>>
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
>>> dataset = get_dataset()
>>> config = CheckpointConfig(saved_network=net)
>>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config) >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
>>> model.train(10, dataset, callbacks=ckpoint_cb) >>> model.train(10, dataset, callbacks=ckpoint_cb)
""" """
@@ -104,7 +128,8 @@ class CheckpointConfig:
keep_checkpoint_max=5, keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0, keep_checkpoint_per_n_minutes=0,
integrated_save=True, integrated_save=True,
async_save=False):
async_save=False,
saved_network=None):


if save_checkpoint_steps is not None: if save_checkpoint_steps is not None:
save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps) save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
@@ -115,6 +140,9 @@ class CheckpointConfig:
if keep_checkpoint_per_n_minutes is not None: if keep_checkpoint_per_n_minutes is not None:
keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes) keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)


if saved_network is not None and not isinstance(saved_network, nn.Cell):
raise TypeError(f"The type of saved_network must be None or Cell, but got {str(type(saved_network))}.")

if not save_checkpoint_steps and not save_checkpoint_seconds and \ if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
raise ValueError("The input_param can't be all None or 0") raise ValueError("The input_param can't be all None or 0")
@@ -134,6 +162,7 @@ class CheckpointConfig:


self._integrated_save = Validator.check_bool(integrated_save) self._integrated_save = Validator.check_bool(integrated_save)
self._async_save = Validator.check_bool(async_save) self._async_save = Validator.check_bool(async_save)
self._saved_network = saved_network


@property @property
def save_checkpoint_steps(self): def save_checkpoint_steps(self):
@@ -165,12 +194,18 @@ class CheckpointConfig:
"""Get the value of _async_save.""" """Get the value of _async_save."""
return self._async_save return self._async_save


@property
def saved_network(self):
    """Get the value of _saved_network."""
    return self._saved_network

def get_checkpoint_policy(self): def get_checkpoint_policy(self):
"""Get the policy of checkpoint.""" """Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
'save_checkpoint_seconds': self._save_checkpoint_seconds,
'keep_checkpoint_max': self._keep_checkpoint_max,
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
'save_checkpoint_seconds': self.save_checkpoint_seconds,
'keep_checkpoint_max': self.keep_checkpoint_max,
'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
'saved_network': self.saved_network}


return checkpoint_policy return checkpoint_policy


@@ -306,7 +341,8 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network) set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph() cb_params.train_network.exec_checkpoint_graph()


save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
save_checkpoint(network, cur_file, self._config.integrated_save,
self._config.async_save) self._config.async_save)


self._latest_ckpt_file_name = cur_file self._latest_ckpt_file_name = cur_file


+ 31
- 5
mindspore/train/serialization.py View File

@@ -225,7 +225,16 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
logger.info("Save checkpoint process finish.") logger.info("Save checkpoint process finish.")




def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
def _check_param_prefix(filter_prefix, param_name):
"""Checks whether the prefix of parameter name matches the given filter_prefix."""
for prefix in filter_prefix:
if param_name.find(prefix) == 0 \
and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
return True
return False


def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
""" """
Loads checkpoint info from a specified file. Loads checkpoint info from a specified file.


@@ -234,6 +243,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
net (Cell): Cell network. Default: None net (Cell): Cell network. Default: None
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix. Default: False in the param_dict into net with the same suffix. Default: False
filter_prefix (Union[str, list[str], tuple[str]]): Parameter with the filter prefix will not be loaded.
Default: None.


Returns: Returns:
Dict, key is parameter name, value is a Parameter. Dict, key is parameter name, value is a Parameter.
@@ -253,6 +264,19 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
if os.path.getsize(ckpt_file_name) == 0: if os.path.getsize(ckpt_file_name) == 0:
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.") raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")


if filter_prefix is not None:
if not isinstance(filter_prefix, (str, list, tuple)):
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
if isinstance(filter_prefix, str):
filter_prefix = (filter_prefix,)
if not filter_prefix:
raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
for index, prefix in enumerate(filter_prefix):
if not isinstance(prefix, str):
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
f"but got {str(type(prefix))} at index {index}.")

logger.info("Execute load checkpoint process.") logger.info("Execute load checkpoint process.")
checkpoint_list = Checkpoint() checkpoint_list = Checkpoint()


@@ -266,9 +290,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):


parameter_dict = {} parameter_dict = {}
try: try:
element_id = 0
param_data_list = [] param_data_list = []
for element in checkpoint_list.value:
for element_id, element in enumerate(checkpoint_list.value):
if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
continue
data = element.tensor.tensor_content data = element.tensor.tensor_content
data_type = element.tensor.tensor_type data_type = element.tensor.tensor_type
np_type = tensor_to_np_type[data_type] np_type = tensor_to_np_type[data_type]
@@ -296,14 +321,15 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
param_value = param_data.reshape(param_dim) param_value = param_data.reshape(param_dim)
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)


element_id += 1

logger.info("Load checkpoint process finish.") logger.info("Load checkpoint process finish.")


except BaseException as e: except BaseException as e:
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__()) raise RuntimeError(e.__str__())


if not parameter_dict:
raise ValueError(f"The loaded parameter dict is empty after filtering, please check filter_prefix.")

if net is not None: if net is not None:
load_param_into_net(net, parameter_dict, strict_load) load_param_into_net(net, parameter_dict, strict_load)




Loading…
Cancel
Save