# Tutorial 4: Customize Models

We categorize model components into 5 types.

- backbone: usually a fully convolutional network (FCN) that extracts feature maps, e.g., ResNet, MobileNet.
- neck: the component between the backbone and the heads, e.g., FPN, PAFPN.
- head: the component for specific tasks, e.g., bbox prediction and mask prediction.
- roi extractor: the part that extracts RoI features from feature maps, e.g., RoI Align.
- loss: the component in the head that calculates losses, e.g., FocalLoss, L1Loss, and GHMLoss.
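
To see where each of these component types sits in a model config, here is an abridged, illustrative skeleton. It is loosely modeled on a Faster R-CNN + FPN setup, with many required fields left out, so it is not a working config. The helper `collect_types` is hypothetical and not part of MMDetection. It walks the nested dicts and reports the `type` found at each path.

```python
# Abridged, illustrative layout only: many required fields
# (e.g. the RPN head and train/test settings) are omitted.
model = dict(
    type='FasterRCNN',
    backbone=dict(type='ResNet', depth=50),  # backbone
    neck=dict(  # neck
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5),
    roi_head=dict(
        type='StandardRoIHead',
        bbox_roi_extractor=dict(  # roi extractor
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0)),
        bbox_head=dict(  # head
            type='Shared2FCBBoxHead',
            loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False),  # loss
            loss_bbox=dict(type='L1Loss', loss_weight=1.0))))  # loss


def collect_types(cfg, path='model'):
    """Hypothetical helper: map each nested config path to its `type`."""
    found = {}
    if isinstance(cfg, dict):
        if 'type' in cfg:
            found[path] = cfg['type']
        for key, value in cfg.items():
            found.update(collect_types(value, f'{path}.{key}'))
    return found


for path, type_name in collect_types(model).items():
    print(f'{path}: {type_name}')
```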

## Develop new components

### Add a new backbone

Here we show how to develop new components, using MobileNet as an example.

#### 1. Define a new backbone (e.g. MobileNet)

Create a new file `mmdet/models/backbones/mobilenet.py`.

```python
import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        super().__init__()  # required before any submodules are assigned
        # build the layers here, configured by arg1, arg2, ...

    def forward(self, x):
        # should return a tuple of feature maps, one per output stage
        pass
```
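
The `@BACKBONES.register_module()` decorator is what makes the class findable by name later. The toy `Registry` below is a simplified stand-in written for illustration, not mmcv's actual `Registry` class. It stores each decorated class under its class name, refuses duplicate names, and returns the class unchanged.

```python
class Registry:
    """Toy name -> class table (a simplified stand-in, not mmcv's Registry)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, name=None):
        def _register(cls):
            key = name or cls.__name__
            if key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls  # return the class unchanged so it stays usable
        return _register

    def get(self, key):
        return self._module_dict.get(key)


BACKBONES = Registry('backbone')


@BACKBONES.register_module()
class MobileNet:
    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = arg1, arg2


print(BACKBONES.get('MobileNet'))  # the class object itself
```

Registration is a side effect of executing the decorator. For that reason, the module file has to be imported somewhere before the config is built.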

#### 2. Import the module

You can either add the following line to `mmdet/models/backbones/__init__.py`

```python
from .mobilenet import MobileNet
```

or alternatively add

```python
custom_imports = dict(
    imports=['mmdet.models.backbones.mobilenet'],
    allow_failed_imports=False)
```

to the config file to avoid modifying the original code.
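
Conceptually, each string in `custom_imports['imports']` is a dotted module path that is imported when the config file is loaded. Importing the module runs the `register_module()` decorators inside it. The function below is a hypothetical sketch of that behavior built on `importlib.import_module`. It is not MMDetection's code, and it uses standard-library module names so it runs anywhere.

```python
import importlib
import warnings


def import_custom_modules(imports, allow_failed_imports=False):
    """Hypothetical sketch: import each dotted module path in `imports`."""
    modules = []
    for name in imports:
        try:
            modules.append(importlib.import_module(name))
        except ImportError:
            if not allow_failed_imports:
                raise
            warnings.warn(f'{name} failed to import and is skipped.')
            modules.append(None)
    return modules


# Dotted module paths work; a file name with a `.py` suffix is not a module path.
ok = import_custom_modules(['json', 'collections.abc'])
bad = import_custom_modules(['json.py'], allow_failed_imports=True)
```

In this sketch, `allow_failed_imports=False` makes a typo in the module path fail immediately at load time. Otherwise the mistake would only show up later, as a name that was never registered.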

#### 3. Use the backbone in your config file

```python
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...)
```
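
Roughly speaking, when the model is built, the `backbone` dict becomes an object in two steps. First, its `type` key is looked up in the registry. Then the remaining keys are passed to the constructor as keyword arguments. The sketch below shows that step in simplified form. The function `build_component` is hypothetical, a plain dict stands in for the registry, and `arg1`/`arg2` get made-up values.

```python
class MobileNet:
    """Stand-in backbone; arg1 and arg2 are placeholders from the tutorial."""

    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


REGISTRY = {'MobileNet': MobileNet}  # toy registry: name -> class


def build_component(cfg, registry):
    """Hypothetical sketch: pop `type`, look it up, pass the rest as kwargs."""
    args = dict(cfg)  # copy so the caller's config is not mutated
    type_name = args.pop('type')
    if type_name not in registry:
        raise KeyError(f'{type_name} is not registered')
    return registry[type_name](**args)


backbone_cfg = dict(type='MobileNet', arg1=1, arg2=2)
backbone = build_component(backbone_cfg, REGISTRY)
```

Every key other than `type` must match a parameter of the class's `__init__`. In this sketch, an unexpected key fails with a `TypeError` when the object is constructed.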

### Add new necks

#### 1. Define a neck (e.g. PAFPN)

Create a new file `mmdet/models/necks/pafpn.py`.

```python
import torch.nn as nn

from ..builder import NECKS


@NECKS.register_module()
class PAFPN(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False):
        super().__init__()
        # build the layers here (implementation is omitted)

    def forward(self, inputs):
        # implementation is omitted
        pass
```

#### 2. Import the module

You can either add the following line to `mmdet/models/necks/__init__.py`,

```python
from .pafpn import PAFPN
```

or alternatively add

```python
custom_imports = dict(
    imports=['mmdet.models.necks.pafpn'],
    allow_failed_imports=False)
```

to the config file to avoid modifying the original code.

#### 3. Modify the config file

```python
neck=dict(
    type='PAFPN',
    in_channels=[256, 512, 1024, 2048],
    out_channels=256,
    num_outs=5)
```
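
In this config, `in_channels` lists the channel count of each backbone feature map the neck receives, four levels here. `num_outs=5` asks for five outputs, so one extra level must be produced beyond the backbone levels. The hypothetical helper below does that bookkeeping, and only for `end_level=-1`, which it takes to mean "use every level from `start_level` to the last". It is a sketch and does not model other `end_level` values, whose exact semantics depend on the neck implementation.

```python
def plan_neck_levels(in_channels, num_outs, start_level=0, end_level=-1):
    """Hypothetical sketch: which backbone levels feed the neck, and how many
    extra output levels it must create. Only end_level=-1 is handled."""
    if end_level != -1:
        raise NotImplementedError('sketch only covers end_level=-1')
    used = in_channels[start_level:]  # backbone levels consumed by the neck
    if num_outs < len(used):
        raise ValueError('num_outs must cover every used backbone level')
    return used, num_outs - len(used)


used, num_extra = plan_neck_levels([256, 512, 1024, 2048], num_outs=5)
print(used, num_extra)  # [256, 512, 1024, 2048] 1
```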

### Add new heads

Here we show how to develop a new head, using [Double Head R-CNN](https://arxiv.org/abs/1904.06493) as an example.

First, add a new bbox head in `mmdet/models/roi_heads/bbox_heads/double_bbox_head.py`.
Double Head R-CNN implements a new bbox head for object detection.
To implement a bbox head, we mainly need to implement the `__init__` and `forward` methods of the new module, as follows.

```python
from mmdet.models.builder import HEADS
from .bbox_head import BBoxHead


@HEADS.register_module()
class DoubleConvFCBBoxHead(BBoxHead):
    r"""Bbox head used in Double-Head R-CNN

                                      /-> cls
                  /-> shared convs ->
                                      \-> reg
    roi features
                                      /-> cls
                  \-> shared fc    ->
                                      \-> reg
    """  # noqa: W605

    def __init__(self,
                 num_convs=0,
                 num_fcs=0,
                 conv_out_channels=1024,
                 fc_out_channels=1024,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        kwargs.setdefault('with_avg_pool', True)
        super(DoubleConvFCBBoxHead, self).__init__(**kwargs)
        # build the conv branch and the fc branch here (omitted)

    def forward(self, x_cls, x_reg):
        # implementation is omitted; returns cls_score, bbox_pred
        pass
```

Second, implement a new RoI Head if necessary. We plan to inherit the new `DoubleHeadRoIHead` from `StandardRoIHead`, which already implements the following functions (method bodies are omitted here).

```python
import torch

from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
from ..builder import HEADS, build_head, build_roi_extractor
from .base_roi_head import BaseRoIHead
from .test_mixins import BBoxTestMixin, MaskTestMixin


@HEADS.register_module()
class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    """Simplest base roi head including one bbox head and one mask head.
    """

    def init_assigner_sampler(self):
        ...

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        ...

    def init_mask_head(self, mask_roi_extractor, mask_head):
        ...

    def forward_dummy(self, x, proposals):
        ...

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        ...

    def _bbox_forward(self, x, rois):
        ...

    def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
                            img_metas):
        ...

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):
        ...

    def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
        ...

    def simple_test(self,
                    x,
                    proposal_list,
                    img_metas,
                    proposals=None,
                    rescale=False):
        """Test without augmentation."""
```

Double Head's modification is mainly in the `_bbox_forward` logic; it inherits the rest from `StandardRoIHead`.
In `mmdet/models/roi_heads/double_roi_head.py`, we implement the new RoI Head as follows:

```python
from ..builder import HEADS
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head RCNN

    https://arxiv.org/abs/1904.06493
    """

    def __init__(self, reg_roi_scale_factor, **kwargs):
        super(DoubleHeadRoIHead, self).__init__(**kwargs)
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x, rois):
        # classification features: RoIs pooled as given
        bbox_cls_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        # regression features: the same RoIs, rescaled by reg_roi_scale_factor
        bbox_reg_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs],
            rois,
            roi_scale_factor=self.reg_roi_scale_factor)
        if self.with_shared_head:
            bbox_cls_feats = self.shared_head(bbox_cls_feats)
            bbox_reg_feats = self.shared_head(bbox_reg_feats)
        cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            bbox_feats=bbox_cls_feats)
        return bbox_results
```
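
This design relies on plain method overriding. `StandardRoIHead` calls `self._bbox_forward(...)` internally, for example from `_bbox_forward_train`. A subclass that overrides only `_bbox_forward` therefore changes how features are extracted while reusing everything else. The stdlib-only classes below are hypothetical stand-ins, with strings in place of tensors, that show the pattern in isolation.

```python
class BaseHeadSketch:
    """Stand-in for StandardRoIHead: shared logic calls a hook method."""

    def forward_train(self, rois):
        # shared training logic; only the feature step is delegated
        bbox_results = self._bbox_forward(rois)
        return {'bbox_results': bbox_results}

    def _bbox_forward(self, rois):
        feats = f'pool({rois})'
        return dict(cls_feats=feats, reg_feats=feats)


class DoubleHeadSketch(BaseHeadSketch):
    """Stand-in for DoubleHeadRoIHead: overrides only the hook."""

    def __init__(self, reg_roi_scale_factor):
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, rois):
        return dict(
            cls_feats=f'pool({rois})',
            reg_feats=f'pool({rois}, scale={self.reg_roi_scale_factor})')


print(DoubleHeadSketch(1.3).forward_train('rois'))
```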

Last, users need to add the modules to
`mmdet/models/roi_heads/bbox_heads/__init__.py` and `mmdet/models/roi_heads/__init__.py` so that the corresponding registries can find and load them.

Alternatively, users can add

```python
custom_imports = dict(
    imports=['mmdet.models.roi_heads.double_roi_head',
             'mmdet.models.roi_heads.bbox_heads.double_bbox_head'])
```

to the config file to achieve the same goal.

The config file of Double Head R-CNN is as follows:

```python
_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
model = dict(
    roi_head=dict(
        type='DoubleHeadRoIHead',
        reg_roi_scale_factor=1.3,
        bbox_head=dict(
            _delete_=True,
            type='DoubleConvFCBBoxHead',
            num_convs=4,
            num_fcs=2,
            in_channels=256,
            conv_out_channels=1024,
            fc_out_channels=1024,
            roi_feat_size=7,
            num_classes=80,
            bbox_coder=dict(
                type='DeltaXYWHBBoxCoder',
                target_means=[0., 0., 0., 0.],
                target_stds=[0.1, 0.1, 0.2, 0.2]),
            reg_class_agnostic=False,
            loss_cls=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
            loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))
```
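
`reg_roi_scale_factor=1.3` is passed by `DoubleHeadRoIHead._bbox_forward` to the RoI extractor as `roi_scale_factor`, but only for the regression branch. As we understand it, the extractor scales each RoI about its center by this factor before pooling, so the regression branch sees more context around the box. The function below is a plain-Python sketch of that rescaling for a single RoI in `[batch_idx, x1, y1, x2, y2]` format. It is an illustration, not MMDetection's tensor implementation.

```python
def rescale_roi(roi, scale_factor):
    """Scale one RoI [batch_idx, x1, y1, x2, y2] about its center."""
    batch_idx, x1, y1, x2, y2 = roi
    cx, cy = (x1 + x2) * 0.5, (y1 + y2) * 0.5
    new_w = (x2 - x1) * scale_factor
    new_h = (y2 - y1) * scale_factor
    return [batch_idx,
            cx - new_w * 0.5, cy - new_h * 0.5,
            cx + new_w * 0.5, cy + new_h * 0.5]


cls_roi = [0, 10.0, 10.0, 30.0, 50.0]  # used as-is for classification
reg_roi = rescale_roi(cls_roi, 1.3)    # enlarged for regression
print(reg_roi)  # [0, 7.0, 4.0, 33.0, 56.0]
```

A 20 x 40 box becomes 26 x 52, and its center stays at (20, 30).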

Since MMDetection 2.0, the config system supports config inheritance, so users can focus on the modifications.
Double Head R-CNN mainly uses a new `DoubleHeadRoIHead` and a new
`DoubleConvFCBBoxHead`; their arguments are set according to the `__init__` function of each module.
Setting `_delete_=True` in `bbox_head` makes the new head config replace the inherited one instead of being merged into it.
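
To make inheritance and `_delete_=True` concrete, here is a simplified, hypothetical merge function; it is not mmcv's actual implementation. Child keys are merged into the base dict recursively. The exception is a child dict marked `_delete_=True`, which replaces the base dict wholesale. The base values are made up for illustration.

```python
import copy


def merge_cfg(child, base):
    """Simplified sketch of config inheritance: merge `child` into `base`."""
    merged = copy.deepcopy(base)  # never mutate the base config
    for key, value in child.items():
        if isinstance(value, dict) and value.get('_delete_', False):
            # replace the inherited dict entirely, dropping the marker key
            merged[key] = {k: v for k, v in value.items() if k != '_delete_'}
        elif isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(value, merged[key])  # recursive merge
        else:
            merged[key] = value
    return merged


base = dict(roi_head=dict(
    type='StandardRoIHead',
    bbox_head=dict(type='Shared2FCBBoxHead', fc_out_channels=1024)))
child = dict(roi_head=dict(
    type='DoubleHeadRoIHead',
    reg_roi_scale_factor=1.3,
    bbox_head=dict(_delete_=True, type='DoubleConvFCBBoxHead', num_convs=4)))

merged = merge_cfg(child, base)
print(merged['roi_head']['bbox_head'])
```

Without the marker, keys from the base `bbox_head` that the child does not mention (here `fc_out_channels`) would be carried over into the new head's config.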

### Add new loss

Assume you want to add a new loss, `MyLoss`, for bounding box regression.
To add a new loss function, users need to implement it in `mmdet/models/losses/my_loss.py`.
The decorator `weighted_loss` enables the loss to be weighted for each element.

```python
import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss


@weighted_loss
def my_loss(pred, target):
    # element-wise L1 loss; @weighted_loss adds weight/reduction/avg_factor
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # a per-call reduction_override takes precedence over self.reduction
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
```
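
The `@weighted_loss` decorator lets `my_loss(pred, target)` accept the extra `weight`, `reduction` and `avg_factor` arguments used in `MyLoss.forward`. Below is a plain-Python sketch of such a wrapper that works on lists instead of tensors. It follows our understanding of MMDetection's logic: multiply by `weight`, then reduce, and when `avg_factor` is given, `'mean'` divides the sum by `avg_factor`. The name `weighted_loss_sketch` is hypothetical, and the code is a simplified illustration, not the library code.

```python
import functools


def weighted_loss_sketch(loss_func):
    """Wrap an element-wise loss so it accepts weight/reduction/avg_factor."""

    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', avg_factor=None):
        loss = loss_func(pred, target)  # element-wise losses (a list)
        if weight is not None:
            loss = [v * w for v, w in zip(loss, weight)]
        if reduction == 'none':
            return loss
        if avg_factor is None:
            return sum(loss) / len(loss) if reduction == 'mean' else sum(loss)
        if reduction == 'mean':
            return sum(loss) / avg_factor  # custom normalizer
        raise ValueError('avg_factor can only be used with reduction="mean"')

    return wrapper


@weighted_loss_sketch
def my_loss(pred, target):
    return [abs(p - t) for p, t in zip(pred, target)]


print(my_loss([1.0, 2.0, 3.0], [1.0, 1.0, 1.0]))  # 1.0
print(my_loss([1.0, 2.0, 3.0], [1.0, 1.0, 1.0],
              weight=[1, 0, 1], reduction='sum'))  # 2.0
```

In `MyLoss.forward`, the reduced value is then multiplied by `loss_weight`.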

Then users need to add it to `mmdet/models/losses/__init__.py`.

```python
from .my_loss import MyLoss, my_loss
```

Alternatively, you can add

```python
custom_imports = dict(
    imports=['mmdet.models.losses.my_loss'])
```

to the config file to achieve the same goal.

To use it, modify the `loss_xxx` field.
Since `MyLoss` is for regression, you need to modify the `loss_bbox` field in the head.

```python
loss_bbox=dict(type='MyLoss', loss_weight=1.0)
```