Browse Source

!7450 add load ckpt param whether strict load

Merge pull request !7450 from changzherui/add_strict
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ca36c7494a
1 changed files with 11 additions and 8 deletions
  1. +11
    -8
      mindspore/train/serialization.py

+ 11
- 8
mindspore/train/serialization.py View File

@@ -173,10 +173,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F

if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
if not isinstance(integrated_save, bool):
raise TypeError("The parameter integrated_save should be bool, but got {}".format(type(integrated_save)))
if not isinstance(async_save, bool):
raise TypeError("The parameter async_save should be bool, but got {}".format(type(async_save)))
integrated_save = Validator.check_bool(integrated_save)
async_save = Validator.check_bool(async_save)

logger.info("Execute save checkpoint process.")

@@ -227,13 +225,15 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
logger.info("Save checkpoint process finish.")


def load_checkpoint(ckpt_file_name, net=None):
def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
"""
Loads checkpoint info from a specified file.

Args:
ckpt_file_name (str): Checkpoint file name.
net (Cell): Cell network. Default: None
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix. Default: False

Returns:
Dict, key is parameter name, value is a Parameter.
@@ -305,18 +305,20 @@ def load_checkpoint(ckpt_file_name, net=None):
raise RuntimeError(e.__str__())

if net is not None:
load_param_into_net(net, parameter_dict)
load_param_into_net(net, parameter_dict, strict_load)

return parameter_dict


def load_param_into_net(net, parameter_dict):
def load_param_into_net(net, parameter_dict, strict_load=False):
"""
Loads parameters into network.

Args:
net (Cell): Cell network.
parameter_dict (dict): Parameter dictionary.
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix. Default: False

Raises:
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@@ -331,6 +333,7 @@ def load_param_into_net(net, parameter_dict):
msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
raise TypeError(msg)

strict_load = Validator.check_bool(strict_load)
logger.info("Execute load parameter into net process.")
net.init_parameters_data()
param_not_load = []
@@ -345,7 +348,7 @@ def load_param_into_net(net, parameter_dict):
else:
param_not_load.append(param.name)

if param_not_load:
if param_not_load and not strict_load:
_load_dismatch_prefix_params(net, parameter_dict, param_not_load)

logger.debug("Params not matched(in net but not in parameter_dict):")


Loading…
Cancel
Save