|
|
|
@@ -141,8 +141,11 @@ def train(): |
|
|
|
if args.ckpt_pre_trained:
|
|
|
|
param_dict = load_checkpoint(args.ckpt_pre_trained)
|
|
|
|
if args.filter_weight:
|
|
|
|
filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]
|
|
|
|
for key in list(param_dict.keys()):
|
|
|
|
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
|
|
|
|
for filter_key in filter_list:
|
|
|
|
if filter_key not in key:
|
|
|
|
continue
|
|
|
|
print('filter {}'.format(key))
|
|
|
|
del param_dict[key]
|
|
|
|
load_param_into_net(train_net, param_dict)
|
|
|
|
|