From: @zhao_ting_v Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -430,6 +430,14 @@ epoch: 3 step: 1, loss is 1.5099041 | |||||
| ... | ... | ||||
| ``` | ``` | ||||
| #### Transfer Training | |||||
| You can train your own model based on pretrained model. You can perform transfer training by following steps. | |||||
| 1. Convert your own dataset to Pascal VOC datasets. Otherwise you have to add your own data preprocess code. | |||||
| 2. Set argument `filter_weight` to `True`, `ckpt_pre_trained` to pretrained checkpoint and `num_classes` to the classes of your dataset while calling `train.py`, this will filter the final conv weight from the pretrained model. | |||||
| 3. Build your own bash scripts using new config and arguments for further convenient. | |||||
| ## [Evaluation Process](#contents) | ## [Evaluation Process](#contents) | ||||
| ### Usage | ### Usage | ||||
| @@ -375,6 +375,14 @@ python ${train_code_path}/train.py --data_file=/PATH/TO/MINDRECORD_NAME \ | |||||
| --keep_checkpoint_max=200 >log 2>&1 & | --keep_checkpoint_max=200 >log 2>&1 & | ||||
| ``` | ``` | ||||
| #### 迁移训练 | |||||
| 用户可以根据预训练好的checkpoint进行迁移学习, 步骤如下: | |||||
| 1. 将数据集格式转换为上述VOC数据集格式,或者自行添加数据处理代码。 | |||||
| 2. 运行`train.py`时设置 `filter_weight` 为 `True`, `ckpt_pre_trained` 为预训练模型路径,`num_classes` 为数据集匹配的类别数目, 加载checkpoint中参数时过滤掉最后的卷积的权重。 | |||||
| 3. 重写启动脚本。 | |||||
| ### 结果 | ### 结果 | ||||
| #### Ascend处理器环境运行 | #### Ascend处理器环境运行 | ||||
| @@ -16,6 +16,7 @@ | |||||
| import os | import os | ||||
| import argparse | import argparse | ||||
| import ast | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| @@ -73,6 +74,8 @@ def parse_args(): | |||||
| parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model') | parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model') | ||||
| parser.add_argument('--freeze_bn', action='store_true', help='freeze bn') | parser.add_argument('--freeze_bn', action='store_true', help='freeze bn') | ||||
| parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model') | parser.add_argument('--ckpt_pre_trained', type=str, default='', help='pretrained model') | ||||
| parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, | |||||
| help="Filter the last weight parameters, default is False.") | |||||
| # train | # train | ||||
| parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'], | parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'CPU'], | ||||
| @@ -137,7 +140,13 @@ def train(): | |||||
| # load pretrained model | # load pretrained model | ||||
| if args.ckpt_pre_trained: | if args.ckpt_pre_trained: | ||||
| param_dict = load_checkpoint(args.ckpt_pre_trained) | param_dict = load_checkpoint(args.ckpt_pre_trained) | ||||
| if args.filter_weight: | |||||
| for key in list(param_dict.keys()): | |||||
| if key in ["network.aspp.conv2.weight", "network.aspp.conv2.bias"]: | |||||
| print('filter {}'.format(key)) | |||||
| del param_dict[key] | |||||
| load_param_into_net(train_net, param_dict) | load_param_into_net(train_net, param_dict) | ||||
| print('load_model {} success'.format(args.ckpt_pre_trained)) | |||||
| # optimizer | # optimizer | ||||
| iters_per_epoch = dataset.get_dataset_size() | iters_per_epoch = dataset.get_dataset_size() | ||||
| @@ -315,7 +315,7 @@ epoch time: 150753.701, per step time: 329.157 | |||||
| You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps. | You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps. | ||||
| 1. Convert your own dataset to COCO or VOC style. Otherwise you havet to add your own data preprocess code. | |||||
| 1. Convert your own dataset to COCO or VOC style. Otherwise you have to add your own data preprocess code. | |||||
| 2. Change config.py according to your own dataset, especially the `num_classes`. | 2. Change config.py according to your own dataset, especially the `num_classes`. | ||||
| 3. Set argument `filter_weight` to `True` while calling `train.py`, this will filter the final detection box weight from the pretrained model. | 3. Set argument `filter_weight` to `True` while calling `train.py`, this will filter the final detection box weight from the pretrained model. | ||||
| 4. Build your own bash scripts using new config and arguments for further convenient. | 4. Build your own bash scripts using new config and arguments for further convenient. | ||||
| @@ -320,6 +320,15 @@ The above shell script will run distribute training in the background. You can v | |||||
| ... | ... | ||||
| ``` | ``` | ||||
| ### Transfer Training | |||||
| You can train your own model based on either pretrained classification model or pretrained detection model. You can perform transfer training by following steps. | |||||
| 1. Convert your own dataset to COCO style. Otherwise you have to add your own data preprocess code. | |||||
| 2. Change config.py according to your own dataset, especially the `num_classes`. | |||||
| 3. Set argument `filter_weight` to `True` and `pretrained_checkpoint` to pretrained checkpoint while calling `train.py`, this will filter the final detection box weight from the pretrained model. | |||||
| 4. Build your own bash scripts using new config and arguments for further convenient. | |||||
| ## [Evaluation Process](#contents) | ## [Evaluation Process](#contents) | ||||
| ### Valid | ### Valid | ||||
| @@ -67,3 +67,9 @@ class ConfigYOLOV4CspDarkNet53: | |||||
| # test_param | # test_param | ||||
| test_img_shape = [608, 608] | test_img_shape = [608, 608] | ||||
| # transfer training | |||||
| checkpoint_filter_list = ['feature_map.backblock0.conv6.weight', 'feature_map.backblock0.conv6.bias', | |||||
| 'feature_map.backblock1.conv6.weight', 'feature_map.backblock1.conv6.bias', | |||||
| 'feature_map.backblock2.conv6.weight', 'feature_map.backblock2.conv6.bias', | |||||
| 'feature_map.backblock3.conv6.weight', 'feature_map.backblock3.conv6.bias'] | |||||
| @@ -202,3 +202,15 @@ def load_yolov4_params(args, network): | |||||
| args.logger.info('resume finished') | args.logger.info('resume finished') | ||||
| load_param_into_net(network, param_dict_new) | load_param_into_net(network, param_dict_new) | ||||
| args.logger.info('load_model {} success'.format(args.resume_yolov4)) | args.logger.info('load_model {} success'.format(args.resume_yolov4)) | ||||
| if args.filter_weight: | |||||
| if args.pretrained_checkpoint: | |||||
| param_dict = load_checkpoint(args.pretrained_checkpoint) | |||||
| for key in list(param_dict.keys()): | |||||
| if key in args.checkpoint_filter_list: | |||||
| args.logger.info('filter {}'.format(key)) | |||||
| del param_dict[key] | |||||
| load_param_into_net(network, param_dict) | |||||
| args.logger.info('load_model {} success'.format(args.pretrained_checkpoint)) | |||||
| else: | |||||
| args.logger.warning('Set filter_weight, but not load pretrained_checkpoint, please be careful') | |||||
| @@ -17,6 +17,7 @@ import os | |||||
| import time | import time | ||||
| import argparse | import argparse | ||||
| import datetime | import datetime | ||||
| import ast | |||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.nn.optim.momentum import Momentum | from mindspore.nn.optim.momentum import Momentum | ||||
| @@ -59,6 +60,10 @@ parser.add_argument('--pretrained_backbone', default='', type=str, | |||||
| help='The ckpt file of CspDarkNet53. Default: "".') | help='The ckpt file of CspDarkNet53. Default: "".') | ||||
| parser.add_argument('--resume_yolov4', default='', type=str, | parser.add_argument('--resume_yolov4', default='', type=str, | ||||
| help='The ckpt file of YOLOv4, which used to fine tune. Default: ""') | help='The ckpt file of YOLOv4, which used to fine tune. Default: ""') | ||||
| parser.add_argument('--pretrained_checkpoint', default='', type=str, | |||||
| help='The ckpt file of YoloV4CspDarkNet53. Default: "".') | |||||
| parser.add_argument("--filter_weight", type=ast.literal_eval, default=False, | |||||
| help="Filter the last weight parameters, default is False.") | |||||
| # optimizer and lr related | # optimizer and lr related | ||||
| parser.add_argument('--lr_scheduler', default='cosine_annealing', type=str, | parser.add_argument('--lr_scheduler', default='cosine_annealing', type=str, | ||||
| @@ -173,14 +178,14 @@ if __name__ == "__main__": | |||||
| network = YOLOV4CspDarkNet53(is_training=True) | network = YOLOV4CspDarkNet53(is_training=True) | ||||
| # default is kaiming-normal | # default is kaiming-normal | ||||
| config = ConfigYOLOV4CspDarkNet53() | |||||
| args.checkpoint_filter_list = config.checkpoint_filter_list | |||||
| default_recurisive_init(network) | default_recurisive_init(network) | ||||
| load_yolov4_params(args, network) | load_yolov4_params(args, network) | ||||
| network = YoloWithLossCell(network) | network = YoloWithLossCell(network) | ||||
| args.logger.info('finish get network') | args.logger.info('finish get network') | ||||
| config = ConfigYOLOV4CspDarkNet53() | |||||
| config.label_smooth = args.label_smooth | config.label_smooth = args.label_smooth | ||||
| config.label_smooth_factor = args.label_smooth_factor | config.label_smooth_factor = args.label_smooth_factor | ||||