From 092ca26aa5ceea07915963488e9a379ac60d4158 Mon Sep 17 00:00:00 2001 From: jiangzhenguang Date: Wed, 21 Apr 2021 15:57:25 +0800 Subject: [PATCH] fix deeplabv3 filter --- model_zoo/official/cv/deeplabv3/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/model_zoo/official/cv/deeplabv3/train.py b/model_zoo/official/cv/deeplabv3/train.py index 0b41cee565..4139246d01 100644 --- a/model_zoo/official/cv/deeplabv3/train.py +++ b/model_zoo/official/cv/deeplabv3/train.py @@ -141,8 +141,11 @@ def train(): if args.ckpt_pre_trained: param_dict = load_checkpoint(args.ckpt_pre_trained) if args.filter_weight: + filter_list = ["network.aspp.conv2.weight", "network.aspp.conv2.bias"] for key in list(param_dict.keys()): - if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]: + for filter_key in filter_list: + if filter_key not in key: + continue print('filter {}'.format(key)) del param_dict[key] load_param_into_net(train_net, param_dict)