Browse Source

!15473 fix deeplabv3 filter

From: @jiangzg001
Reviewed-by: @wuxuejian,@linqingke
Signed-off-by: @wuxuejian
pull/15473/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
b98c508e86
1 changed files with 4 additions and 1 deletions
  1. +4
    -1
      model_zoo/official/cv/deeplabv3/train.py

+ 4
- 1
model_zoo/official/cv/deeplabv3/train.py View File

@@ -141,8 +141,11 @@ def train():
if args.ckpt_pre_trained:
param_dict = load_checkpoint(args.ckpt_pre_trained)
if args.filter_weight:
filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]
for key in list(param_dict.keys()):
if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]:
for filter_key in filter_list:
if filter_key not in key:
continue
print('filter {}'.format(key))
del param_dict[key]
load_param_into_net(train_net, param_dict)


Loading…
Cancel
Save