| @@ -173,10 +173,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): | |||
| raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) | |||
| if not isinstance(integrated_save, bool): | |||
| raise TypeError("The parameter integrated_save should be bool, but got {}".format(type(integrated_save))) | |||
| if not isinstance(async_save, bool): | |||
| raise TypeError("The parameter async_save should be bool, but got {}".format(type(async_save))) | |||
| integrated_save = Validator.check_bool(integrated_save) | |||
| async_save = Validator.check_bool(async_save) | |||
| logger.info("Execute save checkpoint process.") | |||
| @@ -227,13 +225,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| logger.info("Save checkpoint process finish.") | |||
| def load_checkpoint(ckpt_file_name, net=None): | |||
| def load_checkpoint(ckpt_file_name, net=None, strict_load=False): | |||
| """ | |||
| Loads checkpoint info from a specified file. | |||
| Args: | |||
| ckpt_file_name (str): Checkpoint file name. | |||
| net (Cell): Cell network. Default: None | |||
| strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter | |||
| in the param_dict into net with the same suffix. Default: False | |||
| Returns: | |||
| Dict, key is parameter name, value is a Parameter. | |||
| @@ -305,18 +305,20 @@ def load_checkpoint(ckpt_file_name, net=None): | |||
| raise RuntimeError(e.__str__()) | |||
| if net is not None: | |||
| load_param_into_net(net, parameter_dict) | |||
| load_param_into_net(net, parameter_dict, strict_load) | |||
| return parameter_dict | |||
| def load_param_into_net(net, parameter_dict): | |||
| def load_param_into_net(net, parameter_dict, strict_load=False): | |||
| """ | |||
| Loads parameters into network. | |||
| Args: | |||
| net (Cell): Cell network. | |||
| parameter_dict (dict): Parameter dictionary. | |||
| strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter | |||
| in the param_dict into net with the same suffix. Default: False | |||
| Raises: | |||
| TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. | |||
| @@ -331,6 +333,7 @@ def load_param_into_net(net, parameter_dict): | |||
| msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict))) | |||
| raise TypeError(msg) | |||
| strict_load = Validator.check_bool(strict_load) | |||
| logger.info("Execute load parameter into net process.") | |||
| net.init_parameters_data() | |||
| param_not_load = [] | |||
| @@ -345,7 +348,7 @@ def load_param_into_net(net, parameter_dict): | |||
| else: | |||
| param_not_load.append(param.name) | |||
| if param_not_load: | |||
| if param_not_load and not strict_load: | |||
| _load_dismatch_prefix_params(net, parameter_dict, param_not_load) | |||
| logger.debug("Params not matched(in net but not in parameter_dict):") | |||