|
|
|
@@ -224,42 +224,60 @@ def load_param_into_net(net, parameter_dict): |
|
|
|
msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict))) |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
logger.info("Execute parameter into net process.") |
|
|
|
param_name_net_not_have = [] |
|
|
|
logger.info("Execute load parameter into net process.") |
|
|
|
for name in parameter_dict: |
|
|
|
b_par_dict_have_par_of_net = False |
|
|
|
for _, param in net.parameters_and_names(): |
|
|
|
if name == param.name: |
|
|
|
b_par_dict_have_par_of_net = True |
|
|
|
if name == param.name and param.layerwise_parallel: |
|
|
|
# layerwise parallel parameter data loaded from checkpoint file, |
|
|
|
# was a complete(merged) data, need to be splited |
|
|
|
if param.layerwise_parallel: |
|
|
|
new_param = parameter_dict[param.name] |
|
|
|
_load_tensor_for_layerwise(new_param, param) |
|
|
|
new_param = parameter_dict[param.name] |
|
|
|
_load_tensor_for_layerwise(new_param, param) |
|
|
|
break |
|
|
|
if not b_par_dict_have_par_of_net: |
|
|
|
param_name_net_not_have.append(name) |
|
|
|
|
|
|
|
param_name_param_dict_not_have = [] |
|
|
|
param_not_load = [] |
|
|
|
for _, param in net.parameters_and_names(): |
|
|
|
if param.name in parameter_dict: |
|
|
|
new_param = parameter_dict[param.name] |
|
|
|
|
|
|
|
if not isinstance(new_param, Parameter): |
|
|
|
logger.error("Failed to combine the net and the parameters.") |
|
|
|
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param))) |
|
|
|
raise TypeError(msg) |
|
|
|
_update_param(param, new_param) |
|
|
|
else: |
|
|
|
param_name_param_dict_not_have.append(param.name) |
|
|
|
param_not_load.append(param.name) |
|
|
|
|
|
|
|
if param_not_load: |
|
|
|
_load_dismatch_prefix_params(net, parameter_dict, param_not_load) |
|
|
|
|
|
|
|
logger.debug("Params not matched(in net but not in parameter_dict):") |
|
|
|
for paramname in param_name_param_dict_not_have: |
|
|
|
logger.debug("%s", paramname) |
|
|
|
logger.debug("Params not matched(in parameter_dict but not in net):") |
|
|
|
for paramname in param_name_net_not_have: |
|
|
|
logger.debug("%s", paramname) |
|
|
|
logger.info("Load parameter into net process finish.") |
|
|
|
for param_name in param_not_load: |
|
|
|
logger.debug("%s", param_name) |
|
|
|
|
|
|
|
logger.info("Load parameter into net finish, {} parameters has not been loaded.".format(len(param_not_load))) |
|
|
|
|
|
|
|
|
|
|
|
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load): |
|
|
|
"""When some net parameter did not load, try to continue load.""" |
|
|
|
prefix_name = "" |
|
|
|
longest_name = param_not_load[0] |
|
|
|
while prefix_name != longest_name and param_not_load: |
|
|
|
logger.debug("Count: {} parameters has not been loaded, try to load continue.".format(len(param_not_load))) |
|
|
|
longest_name = sorted(param_not_load, key=len, reverse=True)[0] |
|
|
|
prefix_name = longest_name |
|
|
|
for net_param_name in param_not_load: |
|
|
|
for dict_name in parameter_dict: |
|
|
|
if dict_name.endswith(net_param_name): |
|
|
|
tmp_name = dict_name[:-len(net_param_name)] |
|
|
|
prefix_name = prefix_name if len(prefix_name) < len(tmp_name) else tmp_name |
|
|
|
|
|
|
|
if prefix_name != longest_name: |
|
|
|
logger.info("Remove parameter prefix name: {}, continue to load.".format(prefix_name)) |
|
|
|
for _, param in net.parameters_and_names(): |
|
|
|
new_param_name = prefix_name + param.name |
|
|
|
if param.name in param_not_load and new_param_name in parameter_dict: |
|
|
|
new_param = parameter_dict[new_param_name] |
|
|
|
_update_param(param, new_param) |
|
|
|
param_not_load.remove(param.name) |
|
|
|
|
|
|
|
|
|
|
|
def _save_graph(network, file_name): |
|
|
|
|