Browse Source

!7450 add load ckpt param whether strict load

Merge pull request !7450 from changzherui/add_strict
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ca36c7494a
1 changed files with 11 additions and 8 deletions
  1. +11
    -8
      mindspore/train/serialization.py

+ 11
- 8
mindspore/train/serialization.py View File

@@ -173,10 +173,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F


if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
if not isinstance(integrated_save, bool):
raise TypeError("The parameter integrated_save should be bool, but got {}".format(type(integrated_save)))
if not isinstance(async_save, bool):
raise TypeError("The parameter async_save should be bool, but got {}".format(type(async_save)))
integrated_save = Validator.check_bool(integrated_save)
async_save = Validator.check_bool(async_save)


logger.info("Execute save checkpoint process.") logger.info("Execute save checkpoint process.")


@@ -227,13 +225,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
logger.info("Save checkpoint process finish.") logger.info("Save checkpoint process finish.")




def load_checkpoint(ckpt_file_name, net=None):
def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
""" """
Loads checkpoint info from a specified file. Loads checkpoint info from a specified file.


Args: Args:
ckpt_file_name (str): Checkpoint file name. ckpt_file_name (str): Checkpoint file name.
net (Cell): Cell network. Default: None net (Cell): Cell network. Default: None
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix. Default: False


Returns: Returns:
Dict, key is parameter name, value is a Parameter. Dict, key is parameter name, value is a Parameter.
@@ -305,18 +305,20 @@ def load_checkpoint(ckpt_file_name, net=None):
raise RuntimeError(e.__str__()) raise RuntimeError(e.__str__())


if net is not None: if net is not None:
load_param_into_net(net, parameter_dict)
load_param_into_net(net, parameter_dict, strict_load)


return parameter_dict return parameter_dict




def load_param_into_net(net, parameter_dict):
def load_param_into_net(net, parameter_dict, strict_load=False):
""" """
Loads parameters into network. Loads parameters into network.


Args: Args:
net (Cell): Cell network. net (Cell): Cell network.
parameter_dict (dict): Parameter dictionary. parameter_dict (dict): Parameter dictionary.
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix. Default: False


Raises: Raises:
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@@ -331,6 +333,7 @@ def load_param_into_net(net, parameter_dict):
msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict))) msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
raise TypeError(msg) raise TypeError(msg)


strict_load = Validator.check_bool(strict_load)
logger.info("Execute load parameter into net process.") logger.info("Execute load parameter into net process.")
net.init_parameters_data() net.init_parameters_data()
param_not_load = [] param_not_load = []
@@ -345,7 +348,7 @@ def load_param_into_net(net, parameter_dict):
else: else:
param_not_load.append(param.name) param_not_load.append(param.name)


if param_not_load:
if param_not_load and not strict_load:
_load_dismatch_prefix_params(net, parameter_dict, param_not_load) _load_dismatch_prefix_params(net, parameter_dict, param_not_load)


logger.debug("Params not matched(in net but not in parameter_dict):") logger.debug("Params not matched(in net but not in parameter_dict):")


Loading…
Cancel
Save