Browse Source

fix deeplabv3 filter

pull/15473/head
jiangzhenguang 5 years ago
parent
commit
092ca26aa5
1 changed files with 4 additions and 1 deletions
  1. +4
    -1
      model_zoo/official/cv/deeplabv3/train.py

+ 4
- 1
model_zoo/official/cv/deeplabv3/train.py View File

@@ -141,8 +141,11 @@ def train():
if args.ckpt_pre_trained:
param_dict = load_checkpoint(args.ckpt_pre_trained)
if args.filter_weight:
filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]
for key in list(param_dict.keys()):
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
for filter_key in filter_list:
if filter_key not in key:
continue
print('filter {}'.format(key))
del param_dict[key]
load_param_into_net(train_net, param_dict)


Loading…
Cancel
Save