diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 90d8816094..8ec1b38804 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -224,42 +224,60 @@ def load_param_into_net(net, parameter_dict): msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict))) raise TypeError(msg) - logger.info("Execute parameter into net process.") - param_name_net_not_have = [] + logger.info("Execute load parameter into net process.") for name in parameter_dict: - b_par_dict_have_par_of_net = False for _, param in net.parameters_and_names(): - if name == param.name: - b_par_dict_have_par_of_net = True + if name == param.name and param.layerwise_parallel: # layerwise parallel parameter data loaded from checkpoint file, # was a complete(merged) data, need to be splited - if param.layerwise_parallel: - new_param = parameter_dict[param.name] - _load_tensor_for_layerwise(new_param, param) + new_param = parameter_dict[param.name] + _load_tensor_for_layerwise(new_param, param) break - if not b_par_dict_have_par_of_net: - param_name_net_not_have.append(name) - param_name_param_dict_not_have = [] + param_not_load = [] for _, param in net.parameters_and_names(): if param.name in parameter_dict: new_param = parameter_dict[param.name] - if not isinstance(new_param, Parameter): logger.error("Failed to combine the net and the parameters.") msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param))) raise TypeError(msg) _update_param(param, new_param) else: - param_name_param_dict_not_have.append(param.name) + param_not_load.append(param.name) + + if param_not_load: + _load_dismatch_prefix_params(net, parameter_dict, param_not_load) logger.debug("Params not matched(in net but not in parameter_dict):") - for paramname in param_name_param_dict_not_have: - logger.debug("%s", paramname) - logger.debug("Params not matched(in parameter_dict but not in net):") - for paramname in param_name_net_not_have: - logger.debug("%s", paramname) - logger.info("Load parameter into net process finish.") + for param_name in param_not_load: + logger.debug("%s", param_name) + + logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load))) + + +def _load_dismatch_prefix_params(net, parameter_dict, param_not_load): + """When some net parameter did not load, try to continue load.""" + prefix_name = "" + longest_name = param_not_load[0] + while prefix_name != longest_name and param_not_load: + logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load))) + longest_name = sorted(param_not_load, key=len, reverse=True)[0] + prefix_name = longest_name + for net_param_name in param_not_load: + for dict_name in parameter_dict: + if dict_name.endswith(net_param_name): + tmp_name = dict_name[:-len(net_param_name)] + prefix_name = prefix_name if len(prefix_name) < len(tmp_name) else tmp_name + + if prefix_name != longest_name: + logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name)) + for _, param in net.parameters_and_names(): + new_param_name = prefix_name + param.name + if param.name in param_not_load and new_param_name in parameter_dict: + new_param = parameter_dict[new_param_name] + _update_param(param, new_param) + param_not_load.remove(param.name) def _save_graph(network, file_name):