From 0c71dd745ced6fbe0cfac9b2fbcda6b1a43710fd Mon Sep 17 00:00:00 2001 From: changzherui Date: Mon, 19 Oct 2020 15:59:02 +0800 Subject: [PATCH] add load ckpt param whether strict load --- mindspore/train/serialization.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index f63716056d..c2eb59b5b6 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -173,10 +173,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) - if not isinstance(integrated_save, bool): - raise TypeError("The parameter integrated_save should be bool, but got {}".format(type(integrated_save))) - if not isinstance(async_save, bool): - raise TypeError("The parameter async_save should be bool, but got {}".format(type(async_save))) + integrated_save = Validator.check_bool(integrated_save) + async_save = Validator.check_bool(async_save) logger.info("Execute save checkpoint process.") @@ -227,13 +225,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F logger.info("Save checkpoint process finish.") -def load_checkpoint(ckpt_file_name, net=None): +def load_checkpoint(ckpt_file_name, net=None, strict_load=False): """ Loads checkpoint info from a specified file. Args: ckpt_file_name (str): Checkpoint file name. net (Cell): Cell network. Default: None + strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter + in the param_dict into net with the same suffix. Default: False Returns: Dict, key is parameter name, value is a Parameter. @@ -305,18 +305,20 @@ def load_checkpoint(ckpt_file_name, net=None): raise RuntimeError(e.__str__()) if net is not None: - load_param_into_net(net, parameter_dict) + load_param_into_net(net, parameter_dict, strict_load) return parameter_dict -def load_param_into_net(net, parameter_dict): +def load_param_into_net(net, parameter_dict, strict_load=False): """ Loads parameters into network. Args: net (Cell): Cell network. parameter_dict (dict): Parameter dictionary. + strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter + in the param_dict into net with the same suffix. Default: False Raises: TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. @@ -331,6 +333,7 @@ def load_param_into_net(net, parameter_dict): msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict))) raise TypeError(msg) + strict_load = Validator.check_bool(strict_load) logger.info("Execute load parameter into net process.") net.init_parameters_data() param_not_load = [] @@ -345,7 +348,7 @@ def load_param_into_net(net, parameter_dict): else: param_not_load.append(param.name) - if param_not_load: + if param_not_load and not strict_load: _load_dismatch_prefix_params(net, parameter_dict, param_not_load) logger.debug("Params not matched(in net but not in parameter_dict):")