|
|
|
@@ -173,10 +173,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F |
|
|
|
|
|
|
|
if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): |
|
|
|
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) |
|
|
|
if not isinstance(integrated_save, bool): |
|
|
|
raise TypeError("The parameter integrated_save should be bool, but got {}".format(type(integrated_save))) |
|
|
|
if not isinstance(async_save, bool): |
|
|
|
raise TypeError("The parameter async_save should be bool, but got {}".format(type(async_save))) |
|
|
|
integrated_save = Validator.check_bool(integrated_save) |
|
|
|
async_save = Validator.check_bool(async_save) |
|
|
|
|
|
|
|
logger.info("Execute save checkpoint process.") |
|
|
|
|
|
|
|
@@ -227,13 +225,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F |
|
|
|
logger.info("Save checkpoint process finish.") |
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(ckpt_file_name, net=None): |
|
|
|
def load_checkpoint(ckpt_file_name, net=None, strict_load=False): |
|
|
|
""" |
|
|
|
Loads checkpoint info from a specified file. |
|
|
|
|
|
|
|
Args: |
|
|
|
ckpt_file_name (str): Checkpoint file name. |
|
|
|
net (Cell): Cell network. Default: None |
|
|
|
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter |
|
|
|
in the param_dict into net with the same suffix. Default: False |
|
|
|
|
|
|
|
Returns: |
|
|
|
Dict, key is parameter name, value is a Parameter. |
|
|
|
@@ -305,18 +305,20 @@ def load_checkpoint(ckpt_file_name, net=None): |
|
|
|
raise RuntimeError(e.__str__()) |
|
|
|
|
|
|
|
if net is not None: |
|
|
|
load_param_into_net(net, parameter_dict) |
|
|
|
load_param_into_net(net, parameter_dict, strict_load) |
|
|
|
|
|
|
|
return parameter_dict |
|
|
|
|
|
|
|
|
|
|
|
def load_param_into_net(net, parameter_dict): |
|
|
|
def load_param_into_net(net, parameter_dict, strict_load=False): |
|
|
|
""" |
|
|
|
Loads parameters into network. |
|
|
|
|
|
|
|
Args: |
|
|
|
net (Cell): Cell network. |
|
|
|
parameter_dict (dict): Parameter dictionary. |
|
|
|
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter |
|
|
|
in the param_dict into net with the same suffix. Default: False |
|
|
|
|
|
|
|
Raises: |
|
|
|
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. |
|
|
|
@@ -331,6 +333,7 @@ def load_param_into_net(net, parameter_dict): |
|
|
|
msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict))) |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
strict_load = Validator.check_bool(strict_load) |
|
|
|
logger.info("Execute load parameter into net process.") |
|
|
|
net.init_parameters_data() |
|
|
|
param_not_load = [] |
|
|
|
@@ -345,7 +348,7 @@ def load_param_into_net(net, parameter_dict): |
|
|
|
else: |
|
|
|
param_not_load.append(param.name) |
|
|
|
|
|
|
|
if param_not_load: |
|
|
|
if param_not_load and not strict_load: |
|
|
|
_load_dismatch_prefix_params(net, parameter_dict, param_not_load) |
|
|
|
|
|
|
|
logger.debug("Params not matched(in net but not in parameter_dict):") |
|
|
|
|