You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

3_exist_data_new_model.md 10 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # 3: 在标准数据集上训练自定义模型
  2. 在本文中,你将知道如何在标准数据集上训练、测试和推理自定义模型。我们将在 cityscapes 数据集上以自定义 Cascade Mask R-CNN R50 模型为例演示整个过程,为了方便说明,我们将 neck 模块中的 `FPN` 替换为 `AugFPN`,并且在训练中的自动增强类中增加 `Rotate` 或 `Translate`。
  3. 基本步骤如下所示:
  4. 1. 准备标准数据集
  5. 2. 准备你的自定义模型
  6. 3. 准备配置文件
  7. 4. 在标准数据集上对模型进行训练、测试和推理
  8. ## 准备标准数据集
  9. 在本文中,我们使用 cityscapes 标准数据集为例进行说明。
  10. 推荐将数据集根路径采用符号链接方式链接到 `$MMDETECTION/data`。
  11. 如果你的文件结构不同,你可能需要在配置文件中进行相应的路径更改。标准的文件组织格式如下所示:
  12. ```none
  13. mmdetection
  14. ├── mmdet
  15. ├── tools
  16. ├── configs
  17. ├── data
  18. │ ├── coco
  19. │ │ ├── annotations
  20. │ │ ├── train2017
  21. │ │ ├── val2017
  22. │ │ ├── test2017
  23. │ ├── cityscapes
  24. │ │ ├── annotations
  25. │ │ ├── leftImg8bit
  26. │ │ │ ├── train
  27. │ │ │ ├── val
  28. │ │ ├── gtFine
  29. │ │ │ ├── train
  30. │ │ │ ├── val
  31. │ ├── VOCdevkit
  32. │ │ ├── VOC2007
  33. │ │ ├── VOC2012
  34. ```
  35. 你需要使用脚本 `tools/dataset_converters/cityscapes.py` 将 cityscapes 标注转化为 coco 标注格式。
  36. ```shell
  37. pip install cityscapesscripts
  38. python tools/dataset_converters/cityscapes.py ./data/cityscapes --nproc 8 --out-dir ./data/cityscapes/annotations
  39. ```
  40. 目前在 `cityscapes `文件夹中的配置文件所对应模型是采用 COCO 预训练权重进行初始化的。
  41. 如果你的网络不可用或者比较慢,建议你先手动下载对应的预训练权重,否则可能在训练开始时候出现错误。
  42. ## 准备你的自定义模型
  43. 第二步是准备你的自定义模型或者训练相关配置。假设你想在已有的 Cascade Mask R-CNN R50 检测模型基础上,新增一个新的 neck 模块 `AugFPN` 去代替默认的 `FPN`,以下是具体实现:
  44. ### 1 定义新的 neck (例如 AugFPN)
  45. 首先创建新文件 `mmdet/models/necks/augfpn.py`.
  46. ```python
  47. from ..builder import NECKS
  48. @NECKS.register_module()
  49. class AugFPN(nn.Module):
  50. def __init__(self,
  51. in_channels,
  52. out_channels,
  53. num_outs,
  54. start_level=0,
  55. end_level=-1,
  56. add_extra_convs=False):
  57. pass
  58. def forward(self, inputs):
  59. # implementation is ignored
  60. pass
  61. ```
  62. ### 2 导入模块
  63. 你可以采用两种方式导入模块,第一种是在 `mmdet/models/necks/__init__.py` 中添加如下内容
  64. ```python
  65. from .augfpn import AugFPN
  66. ```
  67. 第二种是增加如下代码到对应配置中,这种方式的好处是不需要改动代码
  68. ```python
  69. custom_imports = dict(
  70. imports=['mmdet.models.necks.augfpn.py'],
  71. allow_failed_imports=False)
  72. ```
  73. ### 3 修改配置
  74. ```python
  75. neck=dict(
  76. type='AugFPN',
  77. in_channels=[256, 512, 1024, 2048],
  78. out_channels=256,
  79. num_outs=5)
  80. ```
  81. 关于自定义模型其余相关细节例如实现新的骨架网络,头部网络、损失函数,以及运行时训练配置例如定义新的优化器、使用梯度裁剪、定制训练调度策略和钩子等,请参考文档 [自定义模型](tutorials/customize_models.md) 和 [自定义运行时训练配置](tutorials/customize_runtime.md)。
  82. ## 准备配置文件
  83. 第三步是准备训练配置所需要的配置文件。假设你打算基于 cityscapes 数据集,在 Cascade Mask R-CNN R50 中新增 `AugFPN` 模块,同时增加 `Rotate` 或者 `Translate` 数据增强策略,假设你的配置文件位于 `configs/cityscapes/` 目录下,并且取名为 `cascade_mask_rcnn_r50_augfpn_autoaug_10e_cityscapes.py`,则配置信息如下:
  84. ```python
  85. # 继承 base 配置,然后进行针对性修改
  86. _base_ = [
  87. '../_base_/models/cascade_mask_rcnn_r50_fpn.py',
  88. '../_base_/datasets/cityscapes_instance.py', '../_base_/default_runtime.py'
  89. ]
  90. model = dict(
  91. # 设置为 None,表示不加载 ImageNet 预训练权重,
  92. # 后续可以设置 `load_from` 参数用来加载 COCO 预训练权重
  93. backbone=dict(init_cfg=None),
  94. pretrained=None,
  95. # 使用新增的 `AugFPN` 模块代替默认的 `FPN`
  96. neck=dict(
  97. type='AugFPN',
  98. in_channels=[256, 512, 1024, 2048],
  99. out_channels=256,
  100. num_outs=5),
  101. # 我们也需要将 num_classes 从 80 修改为 8 来匹配 cityscapes 数据集标注
  102. # 这个修改包括 `bbox_head` 和 `mask_head`.
  103. roi_head=dict(
  104. bbox_head=[
  105. dict(
  106. type='Shared2FCBBoxHead',
  107. in_channels=256,
  108. fc_out_channels=1024,
  109. roi_feat_size=7,
  110. # 将 COCO 类别修改为 cityscapes 类别
  111. num_classes=8,
  112. bbox_coder=dict(
  113. type='DeltaXYWHBBoxCoder',
  114. target_means=[0., 0., 0., 0.],
  115. target_stds=[0.1, 0.1, 0.2, 0.2]),
  116. reg_class_agnostic=True,
  117. loss_cls=dict(
  118. type='CrossEntropyLoss',
  119. use_sigmoid=False,
  120. loss_weight=1.0),
  121. loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
  122. loss_weight=1.0)),
  123. dict(
  124. type='Shared2FCBBoxHead',
  125. in_channels=256,
  126. fc_out_channels=1024,
  127. roi_feat_size=7,
  128. # 将 COCO 类别修改为 cityscapes 类别
  129. num_classes=8,
  130. bbox_coder=dict(
  131. type='DeltaXYWHBBoxCoder',
  132. target_means=[0., 0., 0., 0.],
  133. target_stds=[0.05, 0.05, 0.1, 0.1]),
  134. reg_class_agnostic=True,
  135. loss_cls=dict(
  136. type='CrossEntropyLoss',
  137. use_sigmoid=False,
  138. loss_weight=1.0),
  139. loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
  140. loss_weight=1.0)),
  141. dict(
  142. type='Shared2FCBBoxHead',
  143. in_channels=256,
  144. fc_out_channels=1024,
  145. roi_feat_size=7,
  146. # 将 COCO 类别修改为 cityscapes 类别
  147. num_classes=8,
  148. bbox_coder=dict(
  149. type='DeltaXYWHBBoxCoder',
  150. target_means=[0., 0., 0., 0.],
  151. target_stds=[0.033, 0.033, 0.067, 0.067]),
  152. reg_class_agnostic=True,
  153. loss_cls=dict(
  154. type='CrossEntropyLoss',
  155. use_sigmoid=False,
  156. loss_weight=1.0),
  157. loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
  158. ],
  159. mask_head=dict(
  160. type='FCNMaskHead',
  161. num_convs=4,
  162. in_channels=256,
  163. conv_out_channels=256,
  164. # 将 COCO 类别修改为 cityscapes 类别
  165. num_classes=8,
  166. loss_mask=dict(
  167. type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))))
  168. # 覆写 `train_pipeline`,然后新增 `AutoAugment` 训练配置
  169. img_norm_cfg = dict(
  170. mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
  171. train_pipeline = [
  172. dict(type='LoadImageFromFile'),
  173. dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
  174. dict(
  175. type='AutoAugment',
  176. policies=[
  177. [dict(
  178. type='Rotate',
  179. level=5,
  180. img_fill_val=(124, 116, 104),
  181. prob=0.5,
  182. scale=1)
  183. ],
  184. [dict(type='Rotate', level=7, img_fill_val=(124, 116, 104)),
  185. dict(
  186. type='Translate',
  187. level=5,
  188. prob=0.5,
  189. img_fill_val=(124, 116, 104))
  190. ],
  191. ]),
  192. dict(
  193. type='Resize', img_scale=[(2048, 800), (2048, 1024)], keep_ratio=True),
  194. dict(type='RandomFlip', flip_ratio=0.5),
  195. dict(type='Normalize', **img_norm_cfg),
  196. dict(type='Pad', size_divisor=32),
  197. dict(type='DefaultFormatBundle'),
  198. dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
  199. ]
  200. # 设置每张显卡的批处理大小,同时设置新的训练 pipeline
  201. data = dict(
  202. samples_per_gpu=1,
  203. workers_per_gpu=3,
  204. # 用新的训练 pipeline 配置覆写 pipeline
  205. train=dict(dataset=dict(pipeline=train_pipeline)))
  206. # 设置优化器
  207. optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
  208. optimizer_config = dict(grad_clip=None)
  209. # 设置定制的学习率策略
  210. lr_config = dict(
  211. policy='step',
  212. warmup='linear',
  213. warmup_iters=500,
  214. warmup_ratio=0.001,
  215. step=[8])
  216. runner = dict(type='EpochBasedRunner', max_epochs=10)
  217. # 我们采用 COCO 预训练过的 Cascade Mask R-CNN R50 模型权重作为初始化权重,可以得到更加稳定的性能
  218. load_from = 'http://download.openmmlab.com/mmdetection/v2.0/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco/cascade_mask_rcnn_r50_fpn_1x_coco_20200203-9d4dcb24.pth'
  219. ```
  220. ## 训练新模型
  221. 为了能够使用新增配置来训练模型,你可以运行如下命令:
  222. ```shell
  223. python tools/train.py configs/cityscapes/cascade_mask_rcnn_r50_augfpn_autoaug_10e_cityscapes.py
  224. ```
  225. 如果想了解更多用法,可以参考 [例子1](1_exist_data_model.md)。
  226. ## 测试和推理
  227. 为了能够测试训练好的模型,你可以运行如下命令:
  228. ```shell
  229. python tools/test.py configs/cityscapes/cascade_mask_rcnn_r50_augfpn_autoaug_10e_cityscapes.py work_dirs/cascade_mask_rcnn_r50_augfpn_autoaug_10e_cityscapes.py/latest.pth --eval bbox segm
  230. ```
  231. 如果想了解更多用法,可以参考 [例子1](1_exist_data_model.md)。

No Description

Contributors (1)