From 76ab430d4f38288560c3d1207fe5bd5565f8bc93 Mon Sep 17 00:00:00 2001
From: caozhou
Date: Thu, 15 Oct 2020 16:59:29 +0800
Subject: [PATCH] custom ckpt save and load

---
 mindspore/train/callback/_checkpoint.py | 52 +++++++++++++++++++++----
 mindspore/train/serialization.py        | 36 ++++++++++++++---
 2 files changed, 75 insertions(+), 13 deletions(-)

diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py
index c59440a51d..35cc198068 100644
--- a/mindspore/train/callback/_checkpoint.py
+++ b/mindspore/train/callback/_checkpoint.py
@@ -21,6 +21,7 @@ import time
 import threading
 import mindspore.context as context
 from mindspore import log as logger
+from mindspore import nn
 from mindspore._checkparam import Validator
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
@@ -88,13 +89,36 @@ class CheckpointConfig:
         integrated_save (bool): Whether to perform integrated save function in automatic model
                                 parallel scene. Default: True. Integrated save function is only supported in
                                 automatic parallel scene, not supported in manual parallel.
-        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
+        async_save (bool): Whether to save the checkpoint file asynchronously. Default: False.
+        saved_network (Cell): Network to be saved in the checkpoint file. Default: None.
 
     Raises:
         ValueError: If the input_param is None or 0.
 
     Examples:
-        >>> config = CheckpointConfig()
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
+        >>>         self.bn = nn.BatchNorm2d(64)
+        >>>         self.relu = nn.ReLU()
+        >>>         self.flatten = nn.Flatten()
+        >>>         self.fc = nn.Dense(64*224*224, 12)
+        >>>
+        >>>     def construct(self, x):
+        >>>         x = self.conv(x)
+        >>>         x = self.bn(x)
+        >>>         x = self.relu(x)
+        >>>         x = self.flatten(x)
+        >>>         out = self.fc(x)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
+        >>> dataset = get_dataset()
+        >>> config = CheckpointConfig(saved_network=net)
         >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
         >>> model.train(10, dataset, callbacks=ckpoint_cb)
     """
@@ -104,7 +128,8 @@ class CheckpointConfig:
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
                  integrated_save=True,
-                 async_save=False):
+                 async_save=False,
+                 saved_network=None):
 
         if save_checkpoint_steps is not None:
             save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
@@ -115,6 +140,9 @@ class CheckpointConfig:
         if keep_checkpoint_per_n_minutes is not None:
             keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)
 
+        if saved_network is not None and not isinstance(saved_network, nn.Cell):
+            raise TypeError(f"The type of saved_network must be None or Cell, but got {str(type(saved_network))}.")
+
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
             raise ValueError("The input_param can't be all None or 0")
@@ -134,6 +162,7 @@ class CheckpointConfig:
 
         self._integrated_save = Validator.check_bool(integrated_save)
         self._async_save = Validator.check_bool(async_save)
+        self._saved_network = saved_network
 
     @property
     def save_checkpoint_steps(self):
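For context, a minimal sketch of what the new saved_network option changes (Net and get_dataset() are the illustrative names from the docstring example above, not part of this patch): without it, ModelCheckpoint serializes cb_params.train_network, whose parameter list also contains optimizer state; with saved_network=net, only the bare network's parameters land in the checkpoint file.

    from mindspore import nn
    from mindspore.train import Model
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    net = Net()  # the Net defined in the docstring example above
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=optim)

    # saved_network=net: the checkpoint holds only net's parameters. Omitting it
    # falls back to cb_params.train_network, which also carries optimizer state
    # such as Momentum's `moments.*` accumulation parameters.
    config = CheckpointConfig(save_checkpoint_steps=100, saved_network=net)
    ckpt_cb = ModelCheckpoint(prefix="ck_prefix", directory="./", config=config)
    model.train(10, get_dataset(), callbacks=ckpt_cb)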
@@ -165,12 +194,18 @@ class CheckpointConfig:
         """Get the value of _async_save."""
         return self._async_save
 
+    @property
+    def saved_network(self):
+        """Get the value of _saved_network."""
+        return self._saved_network
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
-        checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
-                             'save_checkpoint_seconds': self._save_checkpoint_seconds,
-                             'keep_checkpoint_max': self._keep_checkpoint_max,
-                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
+        checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
+                             'save_checkpoint_seconds': self.save_checkpoint_seconds,
+                             'keep_checkpoint_max': self.keep_checkpoint_max,
+                             'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
+                             'saved_network': self.saved_network}
 
         return checkpoint_policy
@@ -306,7 +341,8 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
+            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
+            save_checkpoint(network, cur_file, self._config.integrated_save,
                             self._config.async_save)
 
             self._latest_ckpt_file_name = cur_file
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index c2eb59b5b6..c4d35b4c80 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -225,7 +225,16 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
     logger.info("Save checkpoint process finish.")
 
 
-def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
+def _check_param_prefix(filter_prefix, param_name):
+    """Check whether the prefix of the parameter name matches the given filter_prefix."""
+    for prefix in filter_prefix:
+        if param_name.find(prefix) == 0 \
+                and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
+            return True
+    return False
+
+
+def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
     """
     Loads checkpoint info from a specified file.
 
@@ -234,6 +243,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
         net (Cell): Cell network. Default: None
         strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
                             in the param_dict into net with the same suffix. Default: False
+        filter_prefix (Union[str, list[str], tuple[str]]): Parameters whose names start with the filter_prefix
+            will not be loaded. Default: None.
 
     Returns:
         Dict, key is parameter name, value is a Parameter.
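The matching rule in _check_param_prefix is worth spelling out: a prefix matches whole dotted name segments, not arbitrary substrings, so "conv" filters "conv.weight" and the exact name "conv" but not "conv1.weight". A standalone sketch of the same logic (function name mirrored here only for illustration), with a few expected outcomes:

    def check_param_prefix(filter_prefix, param_name):
        """True if param_name equals a prefix, or starts with it at a '.'
        boundary, or the prefix itself already ends with '.'."""
        for prefix in filter_prefix:
            if param_name.find(prefix) == 0 \
                    and (param_name == prefix or param_name[len(prefix)] == "."
                         or (prefix and prefix[-1] == ".")):
                return True
        return False

    assert check_param_prefix(["conv"], "conv.weight")        # segment boundary
    assert check_param_prefix(["conv"], "conv")               # exact match
    assert not check_param_prefix(["conv"], "conv1.weight")   # not a raw substring match
    assert check_param_prefix(["conv1."], "conv1.weight")     # trailing-dot prefix

Note the indexing param_name[len(prefix)] is safe: it is only reached when param_name starts with prefix and is strictly longer than it.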
@@ -253,6 +264,19 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
     if os.path.getsize(ckpt_file_name) == 0:
         raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
 
+    if filter_prefix is not None:
+        if not isinstance(filter_prefix, (str, list, tuple)):
+            raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
+                            f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
+        if isinstance(filter_prefix, str):
+            filter_prefix = (filter_prefix,)
+        if not filter_prefix:
+            raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
+        for index, prefix in enumerate(filter_prefix):
+            if not isinstance(prefix, str):
+                raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
+                                f"but got {str(type(prefix))} at index {index}.")
+
     logger.info("Execute load checkpoint process.")
     checkpoint_list = Checkpoint()
 
@@ -266,9 +290,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
     parameter_dict = {}
     try:
-        element_id = 0
         param_data_list = []
-        for element in checkpoint_list.value:
+        for element_id, element in enumerate(checkpoint_list.value):
+            if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
+                continue
             data = element.tensor.tensor_content
             data_type = element.tensor.tensor_type
             np_type = tensor_to_np_type[data_type]
@@ -296,14 +321,15 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
                 param_value = param_data.reshape(param_dim)
                 parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
 
-            element_id += 1
-
         logger.info("Load checkpoint process finish.")
 
     except BaseException as e:
         logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
         raise RuntimeError(e.__str__())
 
+    if not parameter_dict:
+        raise ValueError("The loaded parameter dict is empty after filtering, please check filter_prefix.")
+
     if net is not None:
         load_param_into_net(net, parameter_dict, strict_load)
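A minimal usage sketch of the new filter_prefix argument, assuming a checkpoint produced by the ModelCheckpoint example above (the file name and prefix names are illustrative; "moments" is the usual prefix of Momentum's accumulation parameters in a train_network checkpoint):

    from mindspore.train.serialization import load_checkpoint, load_param_into_net

    net = Net()  # the Net from the CheckpointConfig docstring example above
    # Skip optimizer state when restoring a full train_network checkpoint into
    # a bare inference network.
    param_dict = load_checkpoint("ck_prefix-10_1875.ckpt",
                                 filter_prefix=["moments", "global_step", "learning_rate"])
    load_param_into_net(net, param_dict, strict_load=False)

If the prefixes filter out every parameter, the new guard at the end of load_checkpoint raises ValueError instead of silently returning an empty dict.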