# Tutorial 4: Customize Models

We categorize model components into five types:

- backbone: usually an FCN network to extract feature maps, e.g., ResNet, MobileNet.
- neck: the component between backbones and heads, e.g., FPN, PAFPN.
- head: the component for specific tasks, e.g., bbox prediction and mask prediction.
- roi extractor: the part for extracting RoI features from feature maps, e.g., RoI Align.
- loss: the component in head for calculating losses, e.g., FocalLoss, L1Loss, and GHMLoss.
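
To make the five categories concrete, here is an abridged, illustrative sketch of how they nest in a two-stage detector config, modeled loosely on Faster R-CNN with a ResNet-50 + FPN backbone. Most fields are omitted, and the `component_paths` helper is hypothetical, written only to list where each component lives.

```python
# Abridged, illustrative model config: most fields are omitted.
model = dict(
    type='FasterRCNN',
    backbone=dict(type='ResNet', depth=50),                 # backbone
    neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048],
              out_channels=256, num_outs=5),                # neck
    rpn_head=dict(type='RPNHead'),                          # head
    roi_head=dict(
        type='StandardRoIHead',                             # head
        bbox_roi_extractor=dict(                            # roi extractor
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7)),
        bbox_head=dict(
            type='Shared2FCBBoxHead',                       # head
            loss_cls=dict(type='CrossEntropyLoss'),         # loss
            loss_bbox=dict(type='L1Loss'))))                # loss


def component_paths(cfg, prefix='model'):
    """Hypothetical helper: yield (dotted_path, type) for every nested dict
    that carries a 'type' key."""
    for key, value in cfg.items():
        if isinstance(value, dict):
            path = f'{prefix}.{key}'
            if 'type' in value:
                yield path, value['type']
            yield from component_paths(value, path)


for path, type_name in component_paths(model):
    print(f'{path}: {type_name}')
```

Each `type` string names a class registered in the matching registry; the sections below show how to add your own class for each category.
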

## Develop new components

### Add a new backbone

Here we show how to develop new components with an example of MobileNet.

#### 1. Define a new backbone (e.g. MobileNet)

Create a new file `mmdet/models/backbones/mobilenet.py`.

```python
import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        super().__init__()  # required before defining any submodules
        pass

    def forward(self, x):  # should return a tuple of feature maps
        pass
```

#### 2. Import the module

You can either add the following line to `mmdet/models/backbones/__init__.py`:

```python
from .mobilenet import MobileNet
```

or alternatively add

```python
custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)
```

to the config file to avoid modifying the original code.
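
Roughly speaking, `custom_imports` asks the config loader to import each listed module before the model is built, so the `@BACKBONES.register_module()` decorator runs and `MobileNet` lands in the registry. The function below is a simplified, hypothetical stand-in for that behavior (MMCV performs this step during config loading, and its real implementation may differ); it uses standard-library modules so it runs anywhere.

```python
import importlib
import warnings


def import_custom_modules(imports, allow_failed_imports=False):
    """Hypothetical sketch of `custom_imports`: import each module by its
    dotted path so that registration decorators inside it run."""
    if isinstance(imports, str):
        imports = [imports]
    imported = []
    for name in imports:
        try:
            module = importlib.import_module(name)
        except ImportError:
            if not allow_failed_imports:
                raise
            warnings.warn(f'{name} failed to import and is ignored.')
            module = None
        imported.append(module)
    return imported


# Entries are dotted module paths, not file paths ('json.decoder', not
# 'json/decoder.py' or 'json.decoder.py').
mods = import_custom_modules(['json', 'json.decoder'])
print([m.__name__ for m in mods])

# With allow_failed_imports=True a missing module only triggers a warning.
print(import_custom_modules(['no_such_module_xyz'], allow_failed_imports=True))
```

With `allow_failed_imports=False`, as in the config above, a typo in the module path fails loudly instead of silently leaving the class unregistered.
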

#### 3. Use the backbone in your config file

```python
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...
)
```
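
The `type='MobileNet'` string works because `@BACKBONES.register_module()` recorded the class under its name; the model builder later looks that name up and passes the remaining keys as constructor arguments. The toy `Registry` below is a hypothetical, much-simplified illustration of that mechanism, not MMCV's actual class, and the `arg1`/`arg2` values are placeholders.

```python
class Registry:
    """Toy registry: maps a class name to the class itself."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        def _register(cls):
            if cls.__name__ in self._module_dict:
                raise KeyError(f'{cls.__name__} is already registered '
                               f'in {self.name}')
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        args = dict(cfg)            # copy so the caller's config is untouched
        obj_type = args.pop('type')
        if obj_type not in self._module_dict:
            raise KeyError(f'{obj_type} is not in the {self.name} registry')
        return self._module_dict[obj_type](**args)


BACKBONES = Registry('backbone')


@BACKBONES.register_module()
class MobileNet:
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


backbone_cfg = dict(type='MobileNet', arg1=1, arg2=2)
backbone = BACKBONES.build(backbone_cfg)
print(type(backbone).__name__, backbone.arg1, backbone.arg2)  # MobileNet 1 2
```

This is also why the import step matters: if the module defining `MobileNet` is never imported, the decorator never runs and the lookup fails with a `KeyError`.
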

### Add new necks

#### 1. Define a neck (e.g. PAFPN)

Create a new file `mmdet/models/necks/pafpn.py`.

```python
import torch.nn as nn

from ..builder import NECKS


@NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False):
        super().__init__()  # required before defining any submodules
        pass

    def forward(self, inputs):
        # implementation is ignored
        pass
```

#### 2. Import the module

You can either add the following line to `mmdet/models/necks/__init__.py`:

```python
from .pafpn import PAFPN
```

or alternatively add

```python
custom_imports = dict(
    imports=['mmdet.models.necks.pafpn'],
    allow_failed_imports=False)
```

to the config file to avoid modifying the original code.

#### 3. Modify the config file

```python
neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)
```
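
`in_channels` lists the channel count of each feature map in the tuple returned by the backbone; `[256, 512, 1024, 2048]` matches the four stages of a ResNet-50. When `num_outs` exceeds the number of backbone levels used, FPN-style necks append extra levels on top (in MMDetection's `FPN`, these come from stride-2 max pooling of the last output when `add_extra_convs` is false, and from extra convolutions otherwise). The hypothetical helper below sketches only that level arithmetic, and only for the default `end_level=-1`.

```python
def count_neck_levels(in_channels, num_outs, start_level=0, end_level=-1):
    """Hypothetical helper: how many output levels come from backbone
    features and how many are extra levels added on top."""
    if end_level != -1:
        raise NotImplementedError('this sketch only covers end_level=-1')
    num_ins = len(in_channels)
    used_backbone_levels = num_ins - start_level
    if num_outs < used_backbone_levels:
        raise ValueError('num_outs must cover every used backbone level')
    extra_levels = num_outs - used_backbone_levels
    return used_backbone_levels, extra_levels


# Channels of the four feature maps returned by a ResNet-50 backbone.
print(count_neck_levels([256, 512, 1024, 2048], num_outs=5))  # (4, 1)
print(count_neck_levels([256, 512, 1024, 2048], num_outs=5, start_level=1))  # (3, 2)
```

So the config above produces four outputs from backbone features plus one extra level, five in total.
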

### Add new heads

Here we show how to develop a new head with the example of [Double Head R-CNN](https://arxiv.org/abs/1904.06493).

First, add a new bbox head in `mmdet/models/roi_heads/bbox_heads/double_bbox_head.py`.
Double Head R-CNN implements a new bbox head for object detection.
To implement a bbox head, we mainly need to implement the functions of the new module shown below.

```python
from mmdet.models.builder import HEADS
from .bbox_head import BBoxHead


@HEADS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):
    r"""Bbox head used in Double-Head R-CNN

                                      /-> cls
                  /-> shared convs ->
                                      \-> reg
    roi features
                                      /-> cls
                  \-> shared fc    ->
                                      \-> reg
    """  # noqa: W605

    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 conv_out_channels=1024,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        kwargs.setdefault('with_avg_pool', True)
        super(DoubleConvFCBBoxHead, self).__init__(**kwargs)
        # construction of the conv and fc branches is omitted

    def forward(self, x_cls, x_reg):
        # implementation is omitted: x_cls and x_reg are the RoI features
        # for classification and regression; returns (cls_score, bbox_pred)
        pass
```

Second, implement a new RoI head if necessary. Here we inherit the new `DoubleHeadRoIHead` from `StandardRoIHead`, which already implements the following functions:

```python
import torch

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import HEADS, build_head, build_roi_extractor
from .base_roi_head import BaseRoIHead
from .test_mixins import BBoxTestMixin, MaskTestMixin


@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    """Simplest base roi head including one bbox head and one mask head.
    """
    # Method bodies are omitted; only the signatures are listed.

    def init_assigner_sampler(self):
        ...

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        ...

    def init_mask_head(self, mask_roi_extractor, mask_head):
        ...

    def forward_dummy(self, x, proposals):
        ...

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        ...

    def _bbox_forward(self, x, rois):
        ...

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        ...

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):
        ...

    def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
        ...

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation."""
        ...
```

Double Head's modification is mainly in the `_bbox_forward` logic, and it inherits the rest of its logic from `StandardRoIHead`.
In `mmdet/models/roi_heads/double_roi_head.py`, we implement the new RoI head as follows:

```python
from ..builder import HEADS
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head RCNN

    https://arxiv.org/abs/1904.06493
    """

    def __init__(self, reg_roi_scale_factor, **kwargs):
        super(DoubleHeadRoIHead, self).__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x, rois):
        # classification features are pooled from the original RoIs
        bbox_cls_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        # regression features are pooled from RoIs rescaled by
        # reg_roi_scale_factor
        bbox_reg_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs],
            rois,
            roi_scale_factor=self.reg_roi_scale_factor)
        if self.with_shared_head:
            bbox_cls_feats = self.shared_head(bbox_cls_feats)
            bbox_reg_feats = self.shared_head(bbox_reg_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            bbox_feats=bbox_cls_feats)
        return bbox_results
```
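
`reg_roi_scale_factor` is forwarded to the RoI extractor as `roi_scale_factor`, so the regression branch pools features from a box rescaled about its center (enlarged by 1.3 in the Double Head config). The pure-Python sketch below shows that center-preserving rescale for a single RoI in the `(batch_idx, x1, y1, x2, y2)` layout; `rescale_roi` is a hypothetical name, and MMDetection performs the equivalent computation on tensors of RoIs.

```python
def rescale_roi(roi, scale_factor):
    """Scale a single RoI (batch_idx, x1, y1, x2, y2) about its center."""
    batch_idx, x1, y1, x2, y2 = roi
    cx, cy = (x1 + x2) * 0.5, (y1 + y2) * 0.5
    new_w = (x2 - x1) * scale_factor
    new_h = (y2 - y1) * scale_factor
    return (batch_idx,
            cx - new_w * 0.5, cy - new_h * 0.5,
            cx + new_w * 0.5, cy + new_h * 0.5)


# A 100 x 50 box enlarged by 1.3 keeps its center (70, 45) and becomes 130 x 65.
print(rescale_roi((0, 20.0, 20.0, 120.0, 70.0), 1.3))
```

The enlarged box lets the regression branch see context around the object boundary, while classification keeps using the original RoI.
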

Last, add the modules to `mmdet/models/roi_heads/bbox_heads/__init__.py` and `mmdet/models/roi_heads/__init__.py` so that the corresponding registries can find and load them.

Alternatively, you can add

```python
custom_imports = dict(
    imports=['mmdet.models.roi_heads.double_roi_head',
             'mmdet.models.roi_heads.bbox_heads.double_bbox_head'])
```

to the config file to achieve the same goal.

The config file of Double Head R-CNN is as follows:

```python
_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
model = dict(
    roi_head=dict(
        type='DoubleHeadRoIHead',
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            _delete_=True,
            type='DoubleConvFCBBoxHead',
            num_convs=4,
            num_fcs=2,
            in_channels=256,
            conv_out_channels=1024,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))
```
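
Since MMDetection 2.0, the config system supports config inheritance, so users can focus on the modifications.
Double Head R-CNN mainly uses a new `DoubleHeadRoIHead` and a new `DoubleConvFCBBoxHead`; their arguments are set according to the `__init__` function of each module.
`_delete_=True` makes the new `bbox_head` replace the one inherited from the Faster R-CNN base config instead of being merged into it, which is why all of its fields are listed explicitly.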
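
Config inheritance behaves like a recursive dictionary merge: keys in the child config override or extend those in `_base_`, and a dict containing `_delete_=True` replaces its inherited counterpart wholesale. The function below is a simplified, hypothetical sketch of that rule only (MMCV's `Config` also handles file loading, lists, and more), and the base dict is an abridged stand-in, not the real Faster R-CNN config.

```python
import copy


def merge_cfg(child, base):
    """Hypothetical sketch: recursively merge `child` into a copy of `base`.

    A child dict with `_delete_=True` discards the inherited dict entirely.
    """
    merged = copy.deepcopy(base)
    for key, value in child.items():
        if isinstance(value, dict):
            value = dict(value)                  # do not mutate the child
            delete = value.pop('_delete_', False)
            if not delete and isinstance(merged.get(key), dict):
                merged[key] = merge_cfg(value, merged[key])
            else:
                merged[key] = merge_cfg(value, {})
        else:
            merged[key] = value
    return merged


base = dict(model=dict(roi_head=dict(
    type='StandardRoIHead',
    bbox_head=dict(type='Shared2FCBBoxHead', roi_feat_size=7,
                   num_classes=80))))

child = dict(model=dict(roi_head=dict(
    type='DoubleHeadRoIHead',
    reg_roi_scale_factor=1.3,
    bbox_head=dict(_delete_=True, type='DoubleConvFCBBoxHead',
                   num_convs=4, num_fcs=2))))

cfg = merge_cfg(child, base)
print(cfg['model']['roi_head'])

# Without _delete_, inherited keys such as roi_feat_size leak into bbox_head.
child_no_delete = copy.deepcopy(child)
del child_no_delete['model']['roi_head']['bbox_head']['_delete_']
leaky = merge_cfg(child_no_delete, base)
print(sorted(leaky['model']['roi_head']['bbox_head']))
```

The leaked keys are harmless here, but in a real config they can pass arguments that the new head's `__init__` does not accept.
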

### Add new loss

Assume you want to add a new loss, `MyLoss`, for bounding box regression.
To add a new loss function, implement it in `mmdet/models/losses/my_loss.py`.
The decorator `weighted_loss` enables the loss to be weighted for each element.

```python
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
```
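
Conceptually, `weighted_loss` turns an element-wise loss into one that accepts `weight`, `reduction`, and `avg_factor`: it multiplies the element-wise loss by `weight`, then reduces it with `'none'`, `'mean'`, or `'sum'`, except that a given `avg_factor` with `'mean'` divides the sum by `avg_factor` instead of the element count. The pure-Python sketch below mirrors that behavior on plain lists so the arithmetic is easy to follow; it is an illustration rather than MMDetection's tensor implementation, and the names `weighted_loss_sketch` and `my_l1_loss` are hypothetical.

```python
import functools


def weighted_loss_sketch(loss_func):
    """List-based illustration of a `weighted_loss`-style decorator."""

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', avg_factor=None):
        loss = loss_func(pred, target)                  # element-wise loss
        if weight is not None:
            loss = [l * w for l, w in zip(loss, weight)]
        if avg_factor is None:
            if reduction == 'mean':
                return sum(loss) / len(loss)
            if reduction == 'sum':
                return sum(loss)
            return loss                                 # reduction='none'
        if reduction == 'mean':
            return sum(loss) / avg_factor
        if reduction == 'none':
            return loss
        raise ValueError('avg_factor can not be used with reduction="sum"')

    return wrapper


@weighted_loss_sketch
def my_l1_loss(pred, target):
    assert len(pred) == len(target) and len(target) > 0
    return [abs(p - t) for p, t in zip(pred, target)]


pred, target = [1.0, 2.0, 4.0], [1.5, 2.0, 1.0]         # |diff| = 0.5, 0, 3
print(my_l1_loss(pred, target))                          # mean of the diffs
print(my_l1_loss(pred, target, weight=[1, 1, 0], reduction='sum'))
print(my_l1_loss(pred, target, reduction='mean', avg_factor=2))
```

`MyLoss.forward` then scales this result by `loss_weight`, and `reduction_override` lets a single call use a different reduction than the one configured.
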

Then add it to `mmdet/models/losses/__init__.py`:

```python
from .my_loss import MyLoss, my_loss
```

Alternatively, you can add

```python
custom_imports = dict(
    imports=['mmdet.models.losses.my_loss'])
```

to the config file to achieve the same goal.

To use it, modify the `loss_xxx` field.
Since `MyLoss` is for regression, you need to modify the `loss_bbox` field in the head.

```python
loss_bbox=dict(type='MyLoss', loss_weight=1.0)
```
