|
|
|
@@ -345,7 +345,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False): |
|
|
|
>>> net = Net() |
|
|
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" |
|
|
|
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") |
|
|
|
>>> load_param_into_net(net, param_dict) |
|
|
|
>>> param_not_load = load_param_into_net(net, param_dict) |
|
|
|
>>> print(param_not_load) |
|
|
|
['conv1.weight'] |
|
|
|
""" |
|
|
|
if not isinstance(net, nn.Cell): |
|
|
|
logger.error("Failed to combine the net and the parameters.") |
|
|
|
|