Browse Source

!12560 Add checkpoint filter to resent50 and ssd

From: @c_34
Reviewed-by: @guoqi1024,@wuxuejian
Signed-off-by: @wuxuejian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
cf70c14038
5 changed files with 30 additions and 8 deletions
  1. +15
    -0
      model_zoo/official/cv/resnet/train.py
  2. +1
    -0
      model_zoo/official/cv/ssd/src/config_ssd300.py
  3. +2
    -0
      model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py
  4. +8
    -4
      model_zoo/official/cv/ssd/src/init_params.py
  5. +4
    -4
      model_zoo/official/cv/ssd/train.py

+ 15
- 0
model_zoo/official/cv/resnet/train.py View File

@@ -46,6 +46,8 @@ parser.add_argument('--device_target', type=str, default='Ascend', choices=("Asc
help="Device target, support Ascend, GPU and CPU.")
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--parameter_server', type=ast.literal_eval, default=False, help='Run parameter server train')
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
help="Filter head weight parameters, default is False.")
args_opt = parser.parse_args()

set_seed(1)
@@ -74,6 +76,16 @@ if cfg.optimizer == "Thor":
from src.config import config_thor_gpu as config


def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
"""remove useless parameters according to filter_list"""
for key in list(origin_dict.keys()):
for name in param_filter:
if name in key:
print("Delete parameter from checkpoint: ", key)
del origin_dict[key]
break


if __name__ == '__main__':
target = args_opt.device_target
if target == "CPU":
@@ -119,6 +131,9 @@ if __name__ == '__main__':
# init weight
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
if args_opt.filter_weight:
filter_list = [x.name for x in net.end_point.get_parameters()]
filter_checkpoint_parameter_by_list(param_dict, filter_list)
load_param_into_net(net, param_dict)
else:
for _, cell in net.cells_and_names():


+ 1
- 0
model_zoo/official/cv/ssd/src/config_ssd300.py View File

@@ -50,6 +50,7 @@ config = ed({

# `mindrecord_dir` and `coco_root` are better to use absolute path.
"feature_extractor_base_param": "",
"checkpoint_filter_list": ['multi_loc_layers', 'multi_cls_layers'],
"mindrecord_dir": "/data/MindRecord_COCO",
"coco_root": "/data/coco2017",
"train_data_type": "train2017",


+ 2
- 0
model_zoo/official/cv/ssd/src/config_ssd_mobilenet_v1_fpn.py View File

@@ -54,6 +54,8 @@ config = ed({

# `mindrecord_dir` and `coco_root` are better to use absolute path.
"feature_extractor_base_param": "/ckpt/mobilenet_v1.ckpt",
"checkpoint_filter_list": ['network.multi_box.cls_layers.0.weight', 'network.multi_box.cls_layers.0.bias',
'network.multi_box.loc_layers.0.weight', 'network.multi_box.loc_layers.0.bias'],
"mindrecord_dir": "/data/MindRecord_COCO",
"coco_root": "/data/coco2017",
"train_data_type": "train2017",


+ 8
- 4
model_zoo/official/cv/ssd/src/init_params.py View File

@@ -39,8 +39,12 @@ def load_backbone_params(network, param_dict):
if param_name in param_dict:
param.set_data(param_dict[param_name].data)

def filter_checkpoint_parameter(param_dict):
"""remove useless parameters"""

def filter_checkpoint_parameter_by_list(param_dict, filter_list):
"""remove useless parameters according to filter_list"""
for key in list(param_dict.keys()):
if 'multi_loc_layers' in key or 'multi_cls_layers' in key:
del param_dict[key]
for name in filter_list:
if name in key:
print("Delete parameter from checkpoint: ", key)
del param_dict[key]
break

+ 4
- 4
model_zoo/official/cv/ssd/train.py View File

@@ -29,7 +29,7 @@ from src.ssd import SSD300, SSDWithLossCell, TrainingWrapper, ssd_mobilenet_v2,
from src.config import config
from src.dataset import create_ssd_dataset, create_mindrecord
from src.lr_schedule import get_lr
from src.init_params import init_net_param, filter_checkpoint_parameter
from src.init_params import init_net_param, filter_checkpoint_parameter_by_list

set_seed(1)

@@ -45,7 +45,7 @@ def get_args():
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--lr", type=float, default=0.05, help="Learning rate, default is 0.05.")
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, defalut is coco.")
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
@@ -122,8 +122,8 @@ def main():
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
if args_opt.filter_weight:
filter_checkpoint_parameter(param_dict)
load_param_into_net(net, param_dict)
filter_checkpoint_parameter_by_list(param_dict, config.checkpoint_filter_list)
load_param_into_net(net, param_dict, True)

if args_opt.freeze_layer == "backbone":
for param in backbone.feature_1.trainable_params():


Loading…
Cancel
Save