| @@ -224,42 +224,60 @@ def load_param_into_net(net, parameter_dict): | |||
| msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict))) | |||
| raise TypeError(msg) | |||
| logger.info("Execute parameter into net process.") | |||
| param_name_net_not_have = [] | |||
| logger.info("Execute load parameter into net process.") | |||
| for name in parameter_dict: | |||
| b_par_dict_have_par_of_net = False | |||
| for _, param in net.parameters_and_names(): | |||
| if name == param.name: | |||
| b_par_dict_have_par_of_net = True | |||
| if name == param.name and param.layerwise_parallel: | |||
| # layerwise parallel parameter data loaded from checkpoint file, | |||
| # was a complete(merged) data, need to be splited | |||
| if param.layerwise_parallel: | |||
| new_param = parameter_dict[param.name] | |||
| _load_tensor_for_layerwise(new_param, param) | |||
| new_param = parameter_dict[param.name] | |||
| _load_tensor_for_layerwise(new_param, param) | |||
| break | |||
| if not b_par_dict_have_par_of_net: | |||
| param_name_net_not_have.append(name) | |||
| param_name_param_dict_not_have = [] | |||
| param_not_load = [] | |||
| for _, param in net.parameters_and_names(): | |||
| if param.name in parameter_dict: | |||
| new_param = parameter_dict[param.name] | |||
| if not isinstance(new_param, Parameter): | |||
| logger.error("Failed to combine the net and the parameters.") | |||
| msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param))) | |||
| raise TypeError(msg) | |||
| _update_param(param, new_param) | |||
| else: | |||
| param_name_param_dict_not_have.append(param.name) | |||
| param_not_load.append(param.name) | |||
| if param_not_load: | |||
| _load_dismatch_prefix_params(net, parameter_dict, param_not_load) | |||
| logger.debug("Params not matched(in net but not in parameter_dict):") | |||
| for paramname in param_name_param_dict_not_have: | |||
| logger.debug("%s", paramname) | |||
| logger.debug("Params not matched(in parameter_dict but not in net):") | |||
| for paramname in param_name_net_not_have: | |||
| logger.debug("%s", paramname) | |||
| logger.info("Load parameter into net process finish.") | |||
| for param_name in param_not_load: | |||
| logger.debug("%s", param_name) | |||
| logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load))) | |||
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load):
    """Retry loading net parameters whose checkpoint names carry an extra prefix.

    Each round searches for the shortest string ``prefix`` such that
    ``prefix + name`` is a key of ``parameter_dict`` for some ``name`` in
    ``param_not_load``. Every still-unloaded net parameter whose prefixed
    name is a key of ``parameter_dict`` is then updated through
    ``_update_param`` and removed from ``param_not_load``. Rounds repeat until
    everything is loaded, no prefix can be found, or a round loads nothing.

    Args:
        net: Object whose ``parameters_and_names()`` yields ``(name, param)``
            pairs; each ``param`` exposes a ``name`` attribute.
        parameter_dict (dict): Maps parameter names to the values to load.
        param_not_load (list): Names of net parameters that are not loaded
            yet. Modified in place: names loaded here are removed. An empty
            list is accepted and makes the call a no-op.
    """
    while param_not_load:
        logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load)))

        # ``None`` means "no candidate yet". Using a real name as the sentinel
        # would reject any prefix longer than (or equal to) that name.
        prefix_name = None
        for net_param_name in param_not_load:
            for dict_name in parameter_dict:
                if not dict_name.endswith(net_param_name):
                    continue
                # Explicit end index: ``[:-0]`` would give "" for an empty name.
                tmp_name = dict_name[:len(dict_name) - len(net_param_name)]
                # Keep the shortest prefix; on equal length the later one wins.
                if prefix_name is None or len(tmp_name) <= len(prefix_name):
                    prefix_name = tmp_name

        if prefix_name is None:
            # No checkpoint key ends with any unloaded name: nothing left to try.
            break

        logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name))
        loaded_any = False
        for _, param in net.parameters_and_names():
            new_param_name = prefix_name + param.name
            if param.name in param_not_load and new_param_name in parameter_dict:
                new_param = parameter_dict[new_param_name]
                _update_param(param, new_param)
                param_not_load.remove(param.name)
                loaded_any = True

        if not loaded_any:
            # The same prefix would be selected again next round; stop instead
            # of looping forever.
            break
| def _save_graph(network, file_name): | |||