From: @c_34
Reviewed-by: @guoqi1024, @wuxuejian
Signed-off-by: @wuxuejian
Tags: v1.2.0-rc1
@@ -46,6 +46,8 @@ parser.add_argument('--device_target', type=str, default='Ascend', choices=("Asc
                     help="Device target, support Ascend, GPU and CPU.")
 parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
 parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
+parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
+                    help="Filter head weight parameters, default is False.")
 args_opt = parser.parse_args()

 set_seed(1)
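Note on the new flag: argparse's type=ast.literal_eval parses the command-line string as a Python literal, so --filter_weight True yields a real boolean, whereas type=bool would treat any non-empty string (including "False") as truthy. A minimal, self-contained demonstration:

    import ast

    # Parses Python literals: "True" -> True, "False" -> False.
    print(ast.literal_eval("True"), ast.literal_eval("False"))  # True False
    # Why type=bool would be wrong here: any non-empty string is truthy.
    print(bool("False"))                                        # True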
@@ -74,6 +76,16 @@ if cfg.optimizer == "Thor":
         from src.config import config_thor_gpu as config

+
+def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
+    """remove useless parameters according to filter_list"""
+    for key in list(origin_dict.keys()):
+        for name in param_filter:
+            if name in key:
+                print("Delete parameter from checkpoint: ", key)
+                del origin_dict[key]
+                break
+
 if __name__ == '__main__':
     target = args_opt.device_target
     if target == "CPU":
@@ -119,6 +131,9 @@ if __name__ == '__main__':
     # init weight
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
+        if args_opt.filter_weight:
+            filter_list = [x.name for x in net.end_point.get_parameters()]
+            filter_checkpoint_parameter_by_list(param_dict, filter_list)
         load_param_into_net(net, param_dict)
     else:
         for _, cell in net.cells_and_names():
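Taken together, the hunks above let train.py reuse a pretrained checkpoint when the classification head changes shape (e.g. a different number of classes): the parameters belonging to net.end_point are dropped from the checkpoint before loading, so the freshly initialized head is kept while the backbone weights are restored. A hedged sketch of that flow, assuming the net/end_point names from this diff and a hypothetical checkpoint path:

    from mindspore.train.serialization import load_checkpoint, load_param_into_net

    param_dict = load_checkpoint("pretrained.ckpt")  # hypothetical path
    # Head parameter names, e.g. 'end_point.weight', 'end_point.bias'.
    filter_list = [x.name for x in net.end_point.get_parameters()]
    filter_checkpoint_parameter_by_list(param_dict, filter_list)  # drop the head
    load_param_into_net(net, param_dict)  # backbone weights only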
@@ -50,6 +50,7 @@ config = ed({
     # `mindrecord_dir` and `coco_root` are better to use absolute path.
     "feature_extractor_base_param": "",
+    "checkpoint_filter_list": ['multi_loc_layers', 'multi_cls_layers'],
     "mindrecord_dir": "/data/MindRecord_COCO",
     "coco_root": "/data/coco2017",
     "train_data_type": "train2017",
@@ -54,6 +54,8 @@ config = ed({
     # `mindrecord_dir` and `coco_root` are better to use absolute path.
     "feature_extractor_base_param": "/ckpt/mobilenet_v1.ckpt",
+    "checkpoint_filter_list": ['network.multi_box.cls_layers.0.weight', 'network.multi_box.cls_layers.0.bias',
+                               'network.multi_box.loc_layers.0.weight', 'network.multi_box.loc_layers.0.bias'],
     "mindrecord_dir": "/data/MindRecord_COCO",
     "coco_root": "/data/coco2017",
     "train_data_type": "train2017",
@@ -39,8 +39,12 @@ def load_backbone_params(network, param_dict):
         if param_name in param_dict:
             param.set_data(param_dict[param_name].data)

-def filter_checkpoint_parameter(param_dict):
-    """remove useless parameters"""
+def filter_checkpoint_parameter_by_list(param_dict, filter_list):
+    """remove useless parameters according to filter_list"""
     for key in list(param_dict.keys()):
-        if 'multi_loc_layers' in key or 'multi_cls_layers' in key:
-            del param_dict[key]
+        for name in filter_list:
+            if name in key:
+                print("Delete parameter from checkpoint: ", key)
+                del param_dict[key]
+                break
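With the default lists added to the configs above, the generalized helper reproduces the old hard-coded behavior; iterating over list(param_dict.keys()) is what makes deleting entries during the loop safe. A self-contained illustration, with a plain dict standing in for a loaded checkpoint:

    # Same logic as the diff above; plain dict stands in for a checkpoint.
    def filter_checkpoint_parameter_by_list(param_dict, filter_list):
        """remove useless parameters according to filter_list"""
        for key in list(param_dict.keys()):  # list() -> safe deletion mid-loop
            for name in filter_list:
                if name in key:
                    print("Delete parameter from checkpoint: ", key)
                    del param_dict[key]
                    break

    ckpt = {"backbone.conv1.weight": 0,
            "multi_loc_layers.0.weight": 1,
            "multi_cls_layers.0.bias": 2}
    filter_checkpoint_parameter_by_list(ckpt, ['multi_loc_layers', 'multi_cls_layers'])
    print(sorted(ckpt))  # ['backbone.conv1.weight']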
@@ -29,7 +29,7 @@ from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2,
 from src.config import config
 from src.dataset import create_ssd_dataset, create_mindrecord
 from src.lr_schedule import get_lr
-from src.init_params import init_net_param, filter_checkpoint_parameter
+from src.init_params import init_net_param, filter_checkpoint_parameter_by_list

 set_seed(1)
@@ -45,7 +45,7 @@ def get_args():
     parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
     parser.add_argument("--lr", type=float, default=0.05, help="Learning rate, default is 0.05.")
     parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
-    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.")
+    parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
     parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
     parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
@@ -122,8 +122,8 @@ def main():
     if args_opt.pre_trained:
         param_dict = load_checkpoint(args_opt.pre_trained)
         if args_opt.filter_weight:
-            filter_checkpoint_parameter(param_dict)
-        load_param_into_net(net, param_dict)
+            filter_checkpoint_parameter_by_list(param_dict, config.checkpoint_filter_list)
+        load_param_into_net(net, param_dict, True)
     if args_opt.freeze_layer == "backbone":
         for param in backbone.feature_1.trainable_params():
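One behavioral detail in the last hunk: the new third positional argument to load_param_into_net is MindSpore's strict_load flag. As I read the MindSpore API, True disables the suffix-based fuzzy name matching, so every checkpoint parameter that survives filtering must match a network parameter name exactly; together with the filter list this keeps a reshaped head from being silently mis-loaded. A hedged sketch of the resulting fine-tune call path (names from the hunks above, checkpoint path hypothetical):

    param_dict = load_checkpoint("ssd_pretrained.ckpt")  # hypothetical path
    if args_opt.filter_weight:
        # e.g. ['multi_loc_layers', 'multi_cls_layers'] from the first config above
        filter_checkpoint_parameter_by_list(param_dict, config.checkpoint_filter_list)
    # Third argument is strict_load: surviving names must match the net exactly.
    load_param_into_net(net, param_dict, True)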