From 97059addda71ec5ff6fb8cbc1413c98130015cb4 Mon Sep 17 00:00:00 2001 From: chenhaozhe Date: Mon, 22 Feb 2021 11:11:11 +0800 Subject: [PATCH] add filter-weight in ssd and resnet --- model_zoo/official/cv/resnet/train.py | 15 +++++++++++++++ model_zoo/official/cv/ssd/src/config_ssd300.py | 1 + .../cv/ssd/src/config_ssd_mobilenet_v1_fpn.py | 2 ++ model_zoo/official/cv/ssd/src/init_params.py | 12 ++++++++---- model_zoo/official/cv/ssd/train.py | 8 ++++---- 5 files changed, 30 insertions(+), 8 deletions(-) diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index f56b5d29f0..4f5380fe5c 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -46,6 +46,8 @@ parser.add_argument('--device_target', type=str, default='Ascend', choices=("Asc help="Device target, support Ascend, GPU and CPU.") parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train') +parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, + help="Filter head weight parameters, default is False.") args_opt = parser.parse_args() set_seed(1) @@ -74,6 +76,16 @@ if cfg.optimizer == "Thor": from src.config import config_thor_gpu as config +def filter_checkpoint_parameter_by_list(origin_dict, param_filter): + """remove useless parameters according to filter_list""" + for key in list(origin_dict.keys()): + for name in param_filter: + if name in key: + print("Delete parameter from checkpoint: ", key) + del origin_dict[key] + break + + if __name__ == '__main__': target = args_opt.device_target if target == "CPU": @@ -119,6 +131,9 @@ if __name__ == '__main__': # init weight if args_opt.pre_trained: param_dict = load_checkpoint(args_opt.pre_trained) + if args_opt.filter_weight: + filter_list = [x.name for x in net.end_point.get_parameters()] + filter_checkpoint_parameter_by_list(param_dict, filter_list) load_param_into_net(net, param_dict) else: for _, cell in net.cells_and_names(): diff --git a/model_zoo/official/cv/ssd/src/config_ssd300.py b/model_zoo/official/cv/ssd/src/config_ssd300.py index 036b70c01e..b4cc1e1c29 100644 --- a/model_zoo/official/cv/ssd/src/config_ssd300.py +++ b/model_zoo/official/cv/ssd/src/config_ssd300.py @@ -50,6 +50,7 @@ config = ed({ # `mindrecord_dir` and `coco_root` are better to use absolute path. "feature_extractor_base_param": "", + "checkpoint_filter_list": ['multi_loc_layers', 'multi_cls_layers'], "mindrecord_dir": "/data/MindRecord_COCO", "coco_root": "/data/coco2017", "train_data_type": "train2017", diff --git a/model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py b/model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py index 86c5ea38f6..48d3e60a89 100644 --- a/model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py +++ b/model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py @@ -54,6 +54,8 @@ config = ed({ # `mindrecord_dir` and `coco_root` are better to use absolute path. "feature_extractor_base_param": "/ckpt/mobilenet_v1.ckpt", + "checkpoint_filter_list": ['network.multi_box.cls_layers.0.weight', 'network.multi_box.cls_layers.0.bias', + 'network.multi_box.loc_layers.0.weight', 'network.multi_box.loc_layers.0.bias'], "mindrecord_dir": "/data/MindRecord_COCO", "coco_root": "/data/coco2017", "train_data_type": "train2017", diff --git a/model_zoo/official/cv/ssd/src/init_params.py b/model_zoo/official/cv/ssd/src/init_params.py index 6ffb2ed58f..e99cfe1722 100644 --- a/model_zoo/official/cv/ssd/src/init_params.py +++ b/model_zoo/official/cv/ssd/src/init_params.py @@ -39,8 +39,12 @@ def load_backbone_params(network, param_dict): if param_name in param_dict: param.set_data(param_dict[param_name].data) -def filter_checkpoint_parameter(param_dict): - """remove useless parameters""" + +def filter_checkpoint_parameter_by_list(param_dict, filter_list): + """remove useless parameters according to filter_list""" for key in list(param_dict.keys()): - if 'multi_loc_layers' in key or 'multi_cls_layers' in key: - del param_dict[key] + for name in filter_list: + if name in key: + print("Delete parameter from checkpoint: ", key) + del param_dict[key] + break diff --git a/model_zoo/official/cv/ssd/train.py b/model_zoo/official/cv/ssd/train.py index c34c76b4c0..971e7c5568 100644 --- a/model_zoo/official/cv/ssd/train.py +++ b/model_zoo/official/cv/ssd/train.py @@ -29,7 +29,7 @@ from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2, from src.config import config from src.dataset import create_ssd_dataset, create_mindrecord from src.lr_schedule import get_lr -from src.init_params import init_net_param, filter_checkpoint_parameter +from src.init_params import init_net_param, filter_checkpoint_parameter_by_list set_seed(1) @@ -45,7 +45,7 @@ def get_args(): parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") parser.add_argument("--lr", type=float, default=0.05, help="Learning rate, default is 0.05.") parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.") - parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.") + parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.") parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.") parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.") @@ -122,8 +122,8 @@ def main(): if args_opt.pre_trained: param_dict = load_checkpoint(args_opt.pre_trained) if args_opt.filter_weight: - filter_checkpoint_parameter(param_dict) - load_param_into_net(net, param_dict) + filter_checkpoint_parameter_by_list(param_dict, config.checkpoint_filter_list) + load_param_into_net(net, param_dict, True) if args_opt.freeze_layer == "backbone": for param in backbone.feature_1.trainable_params():