# Tutorial 7: Finetuning Models

Detectors pre-trained on the COCO dataset can serve as a good starting point for other datasets, e.g., the Cityscapes and KITTI datasets.
This tutorial shows how to use the models provided in the [Model Zoo](../model_zoo.md) on other datasets to obtain better performance.

There are two steps to finetune a model on a new dataset.

- Add support for the new dataset following [Tutorial 2: Customize Datasets](customize_dataset.md).
- Modify the configs as discussed in this tutorial.

Taking the finetuning process on the Cityscapes dataset as an example, users need to modify five parts of the config.

## Inherit base configs

To reduce the burden of writing whole configs and the bugs that come with it, MMDetection V2.0 supports inheriting configs from multiple existing configs. To finetune a Mask RCNN model, the new config needs to inherit
`_base_/models/mask_rcnn_r50_fpn.py` to build the basic structure of the model. To use the Cityscapes dataset, the new config can simply inherit `_base_/datasets/cityscapes_instance.py` as well. For runtime settings such as training schedules, the new config needs to inherit `_base_/default_runtime.py`. These configs are in the `configs` directory. Users can also choose to write out the whole contents rather than use inheritance.

```python
_base_ = [
    '../_base_/models/mask_rcnn_r50_fpn.py',
    '../_base_/datasets/cityscapes_instance.py', '../_base_/default_runtime.py'
]
```
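As a rough illustration of what inheritance means here, the sketch below merges a child config into a base config key by key. It is a simplified stand-in written for this tutorial, not the actual merging code used by MMDetection; the helper name `merge_config` and the toy dictionaries are hypothetical.

```python
import copy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base` (hypothetical helper).

    Nested dicts are merged key by key; any other value in `override`
    replaces the corresponding value in `base`.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Toy stand-ins for a base model config and a child config.
base_cfg = dict(
    model=dict(
        type='MaskRCNN',
        roi_head=dict(bbox_head=dict(num_classes=80, in_channels=256))))
child_cfg = dict(model=dict(roi_head=dict(bbox_head=dict(num_classes=8))))

cfg = merge_config(base_cfg, child_cfg)
print(cfg['model']['roi_head']['bbox_head'])
```

In this toy case the merged `bbox_head` keeps `in_channels=256` from the base while `num_classes` becomes 8, which is the behaviour the remaining sections rely on.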
## Modify head

Next, the new config needs to modify the head according to the number of classes in the new dataset. Because only `num_classes` in the `roi_head` changes, most weights of the pre-trained model are reused; only the final prediction layers are re-initialized.

```python
model = dict(
    pretrained=None,
    roi_head=dict(
        bbox_head=dict(
            type='Shared2FCBBoxHead',
            in_channels=256,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=8,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
        mask_head=dict(
            type='FCNMaskHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=8,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))))
```
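Since the inherited base model config already defines the other head fields, a shorter override that touches only the differing values should be equivalent. The fragment below is a sketch under that assumption, not a config shipped with the repository.

```python
# Hypothetical shorter variant: override only the fields that differ from
# the inherited `_base_/models/mask_rcnn_r50_fpn.py`; everything else is
# assumed to come from the base config.
model = dict(
    pretrained=None,
    roi_head=dict(
        bbox_head=dict(num_classes=8),
        mask_head=dict(num_classes=8)))
```

Writing the heads out in full, as above, is more verbose but makes the finetuning config readable without opening the base file.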
## Modify dataset

Users may also need to prepare the dataset and write the dataset config. MMDetection V2.0 already supports the VOC, WIDER FACE, COCO and Cityscapes datasets.

## Modify training schedule

The finetuning hyperparameters differ from the default schedule. Finetuning usually requires a smaller learning rate and fewer training epochs.

```python
# optimizer
# lr is set for a batch size of 8
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[7])
# max_epochs and step in lr_config need to be tuned specifically for the customized dataset
runner = dict(max_epochs=8)
log_config = dict(interval=100)
```
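To get a feel for what this schedule does, the sketch below approximates the learning rate at a given point in training. It is an illustration rather than the library's scheduler: the function names are hypothetical, the decay factor `gamma=0.1` is an assumption (the config above does not set it), and `scaled_lr` applies the commonly used linear scaling rule, which this tutorial does not state explicitly.

```python
def scaled_lr(base_lr, base_batch_size, batch_size):
    """Scale the learning rate linearly with batch size (assumed rule)."""
    return base_lr * batch_size / base_batch_size


def lr_at(epoch, cur_iter, base_lr=0.01, steps=(7,), gamma=0.1,
          warmup_iters=500, warmup_ratio=0.001):
    """Approximate learning rate for a step policy with linear warmup.

    `epoch` is 0-based; `cur_iter` counts iterations from the start of
    training. `gamma` (the decay factor per step) is an assumed default.
    """
    # Step decay: multiply by gamma once for every milestone already reached.
    passed = sum(1 for s in steps if epoch >= s)
    regular_lr = base_lr * gamma ** passed
    # Linear warmup: ramp from warmup_ratio * lr up to lr over warmup_iters.
    if cur_iter < warmup_iters:
        k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
        return regular_lr * (1 - k)
    return regular_lr


# Learning rate at the very first iteration, right after warmup,
# and in the last epoch (after the step at epoch 7).
for epoch, cur_iter in [(0, 0), (0, 500), (7, 10000)]:
    print(epoch, cur_iter, lr_at(epoch, cur_iter))
```

Under these assumptions the rate starts near `0.01 * 0.001`, reaches `0.01` once warmup ends, and drops by a factor of ten for the final epoch. If the total batch size is not 8, `scaled_lr(0.01, 8, new_batch_size)` gives a starting point for `lr` that still needs tuning.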
## Use pre-trained model

To use the pre-trained model, the new config adds the link to the pre-trained weights in `load_from`. Users may want to download the model weights before training so that the download does not happen during training.

```python
load_from = 'https://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth'  # noqa
```
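One possible way to fetch the weights ahead of time is sketched below using only the Python standard library. The helper names and the `checkpoints` directory are hypothetical choices made for this example; any other download tool works equally well.

```python
import os
import urllib.request
from urllib.parse import urlparse


def local_checkpoint_path(url, cache_dir='checkpoints'):
    """Map a checkpoint URL to a file path inside `cache_dir`."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(cache_dir, filename)


def fetch_checkpoint(url, cache_dir='checkpoints'):
    """Download `url` once and return the local path (hypothetical helper)."""
    path = local_checkpoint_path(url, cache_dir)
    if not os.path.exists(path):
        os.makedirs(cache_dir, exist_ok=True)
        urllib.request.urlretrieve(url, path)
    return path


# Example usage, commented out so that nothing is downloaded by accident.
# `checkpoint_url` stands for the URL given in `load_from` above.
# path = fetch_checkpoint(checkpoint_url)
```

After the download, `load_from` in the config can be set to the returned local path instead of the URL.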
