You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

rpn_head.py 11 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from mmcv.cnn import ConvModule
  7. from mmcv.ops import batched_nms
  8. from ..builder import HEADS
  9. from .anchor_head import AnchorHead
  @HEADS.register_module()
  class RPNHead(AnchorHead):
      """RPN head.

      Predicts, for every anchor of every scale level, an objectness score
      (``rpn_cls``) and box regression deltas (``rpn_reg``), and converts them
      into region proposals at test time.

      Args:
          in_channels (int): Number of channels in the input feature map.
          init_cfg (dict or list[dict], optional): Initialization config dict.
              Default: ``dict(type='Normal', layer='Conv2d', std=0.01)``.
          num_convs (int): Number of convolution layers in the head.
              Default: 1.
      """

      def __init__(self,
                   in_channels,
                   init_cfg=dict(type='Normal', layer='Conv2d', std=0.01),
                   num_convs=1,
                   **kwargs):
          # Set before the parent constructor because ``_init_layers`` reads
          # it. NOTE(review): assumes the base class builds its layers from
          # its ``__init__`` — confirm.
          self.num_convs = num_convs
          # The leading ``1`` is presumably AnchorHead's ``num_classes`` (RPN
          # scores a single foreground class). The default ``init_cfg`` dict
          # is shared between calls; it is not mutated in this class.
          super(RPNHead, self).__init__(
              1, in_channels, init_cfg=init_cfg, **kwargs)

      def _init_layers(self):
          """Initialize layers of the head.

          Builds ``rpn_conv`` (a single 3x3 ``nn.Conv2d`` when
          ``num_convs <= 1``, otherwise ``num_convs`` stacked 3x3
          ``ConvModule``s) and the 1x1 ``rpn_cls`` / ``rpn_reg`` predictors.
          """
          if self.num_convs > 1:
              rpn_convs = []
              for i in range(self.num_convs):
                  # Only the first conv consumes the input feature channels.
                  in_channels = (
                      self.in_channels if i == 0 else self.feat_channels)
                  # use ``inplace=False`` to avoid error: one of the variables
                  # needed for gradient computation has been modified by an
                  # inplace operation.
                  rpn_convs.append(
                      ConvModule(
                          in_channels,
                          self.feat_channels,
                          3,
                          padding=1,
                          inplace=False))
              self.rpn_conv = nn.Sequential(*rpn_convs)
          else:
              self.rpn_conv = nn.Conv2d(
                  self.in_channels, self.feat_channels, 3, padding=1)
          self.rpn_cls = nn.Conv2d(self.feat_channels,
                                   self.num_base_priors * self.cls_out_channels,
                                   1)
          self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_base_priors * 4,
                                   1)

      def forward_single(self, x):
          """Forward feature map of a single scale level.

          Args:
              x (Tensor): Feature map of one scale level.

          Returns:
              tuple[Tensor, Tensor]: ``(rpn_cls_score, rpn_bbox_pred)`` as
                  produced by ``rpn_cls`` and ``rpn_reg``.
          """
          x = self.rpn_conv(x)
          # Must NOT be in-place. With ``num_convs > 1`` the last module of
          # ``rpn_conv`` is a ConvModule built with ``inplace=False``; assuming
          # its activation is a ReLU (mmcv's default), that ReLU saves its
          # output for backward, and an in-place ReLU here would bump the
          # tensor's version so backward raises "modified by an inplace
          # operation". ReLU is idempotent, so forward values are unchanged.
          x = F.relu(x, inplace=False)
          rpn_cls_score = self.rpn_cls(x)
          rpn_bbox_pred = self.rpn_reg(x)
          return rpn_cls_score, rpn_bbox_pred

      def loss(self,
               cls_scores,
               bbox_preds,
               gt_bboxes,
               img_metas,
               gt_bboxes_ignore=None):
          """Compute losses of the head.

          Args:
              cls_scores (list[Tensor]): Box scores for each scale level,
                  with shape (N, num_anchors * num_classes, H, W).
              bbox_preds (list[Tensor]): Box energies / deltas for each scale
                  level with shape (N, num_anchors * 4, H, W).
              gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                  shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
              img_metas (list[dict]): Meta information of each image, e.g.,
                  image size, scaling factor, etc.
              gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                  boxes can be ignored when computing the loss.

          Returns:
              dict[str, Tensor]: ``loss_rpn_cls`` and ``loss_rpn_bbox``,
                  renamed from the parent's ``loss_cls`` / ``loss_bbox``.
          """
          # The fourth positional argument (presumably the parent's
          # ``gt_labels``) is None: proposals are not labelled per class.
          losses = super(RPNHead, self).loss(
              cls_scores,
              bbox_preds,
              gt_bboxes,
              None,
              img_metas,
              gt_bboxes_ignore=gt_bboxes_ignore)
          return dict(
              loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])

      def _get_bboxes_single(self,
                             cls_score_list,
                             bbox_pred_list,
                             score_factor_list,
                             mlvl_anchors,
                             img_meta,
                             cfg,
                             rescale=False,
                             with_nms=True,
                             **kwargs):
          """Transform outputs of a single image into bbox predictions.

          For each level, flattens scores and deltas, optionally keeps the
          ``nms_pre`` highest-scoring anchors, and tags every kept anchor with
          its level index. Decoding, size filtering and NMS are delegated to
          ``_bbox_post_process``.

          Args:
              cls_score_list (list[Tensor]): Box scores from all scale
                  levels of a single image, each item has shape
                  (num_anchors * num_classes, H, W).
              bbox_pred_list (list[Tensor]): Box energies / deltas from
                  all scale levels of a single image, each item has
                  shape (num_anchors * 4, H, W).
              score_factor_list (list[Tensor]): Score factor from all scale
                  levels of a single image. Not used by the RPN head.
              mlvl_anchors (list[Tensor]): Anchors of all scale levels,
                  each item has shape (num_anchors, 4).
              img_meta (dict): Image meta info; ``img_meta['img_shape']`` is
                  used to clip decoded boxes.
              cfg (mmcv.Config): Test / postprocessing configuration,
                  if None, test_cfg would be used.
              rescale (bool): Accepted for interface compatibility but not
                  used here: proposals are not rescaled. Default: False.
              with_nms (bool): Accepted for interface compatibility but not
                  used here: NMS is always applied. Default: True.

          Returns:
              Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
                  are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                  5-th column is a score between 0 and 1.
          """
          cfg = self.test_cfg if cfg is None else cfg
          # Private copy so nothing downstream can alter the shared config.
          cfg = copy.deepcopy(cfg)
          img_shape = img_meta['img_shape']

          # bboxes from different level should be independent during NMS,
          # level_ids are used as labels for batched NMS to separate them
          level_ids = []
          mlvl_scores = []
          mlvl_bbox_preds = []
          mlvl_valid_anchors = []
          nms_pre = cfg.get('nms_pre', -1)
          for level_idx in range(len(cls_score_list)):
              rpn_cls_score = cls_score_list[level_idx]
              rpn_bbox_pred = bbox_pred_list[level_idx]
              assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
              # (C, H, W) -> (H, W, C) so the flattening below matches the
              # order of the flattened bbox predictions.
              rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
              if self.use_sigmoid_cls:
                  rpn_cls_score = rpn_cls_score.reshape(-1)
                  scores = rpn_cls_score.sigmoid()
              else:
                  rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                  # We set FG labels to [0, num_class-1] and BG label to
                  # num_class in RPN head since mmdet v2.5, which is unified to
                  # be consistent with other head since mmdet v2.0. In mmdet v2.0
                  # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
                  scores = rpn_cls_score.softmax(dim=1)[:, 0]
              rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
              anchors = mlvl_anchors[level_idx]

              if 0 < nms_pre < scores.shape[0]:
                  # Keep the ``nms_pre`` highest-scoring anchors of this level
                  # (full sort used instead of topk, per the original authors).
                  ranked_scores, rank_inds = scores.sort(descending=True)
                  topk_inds = rank_inds[:nms_pre]
                  scores = ranked_scores[:nms_pre]
                  rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                  anchors = anchors[topk_inds, :]

              mlvl_scores.append(scores)
              mlvl_bbox_preds.append(rpn_bbox_pred)
              mlvl_valid_anchors.append(anchors)
              level_ids.append(
                  scores.new_full((scores.size(0), ),
                                  level_idx,
                                  dtype=torch.long))

          return self._bbox_post_process(mlvl_scores, mlvl_bbox_preds,
                                         mlvl_valid_anchors, level_ids, cfg,
                                         img_shape)

      def _bbox_post_process(self, mlvl_scores, mlvl_bboxes, mlvl_valid_anchors,
                             level_ids, cfg, img_shape, **kwargs):
          """Decode, filter and NMS the proposals of a single image.

          Decodes the deltas against their anchors (clipped to ``img_shape``),
          optionally drops boxes whose width or height is not larger than
          ``cfg.min_bbox_size``, then runs NMS separately within each level by
          using the level ids as ``batched_nms`` categories. Boxes are NOT
          rescaled to the original image.

          Args:
              mlvl_scores (list[Tensor]): Box scores from all scale levels of
                  a single image, each item has shape (num_bboxes, ).
              mlvl_bboxes (list[Tensor]): Box deltas from all scale levels of
                  a single image, each item has shape (num_bboxes, 4).
              mlvl_valid_anchors (list[Tensor]): Anchors of all scale levels,
                  each item has shape (num_bboxes, 4).
              level_ids (list[Tensor]): Level index of every box from all
                  scale levels, each item has shape (num_bboxes, ).
              cfg (mmcv.Config): Test / postprocessing configuration; reads
                  ``min_bbox_size``, ``nms`` and ``max_per_img``.
              img_shape (tuple(int)): Shape of current image.

          Returns:
              Tensor: Labeled boxes in shape (n, 5), where the first 4 columns
                  are bounding box positions (tl_x, tl_y, br_x, br_y) and the
                  5-th column is a score between 0 and 1. At most
                  ``cfg.max_per_img`` rows; an empty (0, 5) tensor when no
                  proposal survives filtering.
          """
          scores = torch.cat(mlvl_scores)
          anchors = torch.cat(mlvl_valid_anchors)
          rpn_bbox_pred = torch.cat(mlvl_bboxes)
          proposals = self.bbox_coder.decode(
              anchors, rpn_bbox_pred, max_shape=img_shape)
          ids = torch.cat(level_ids)

          if cfg.min_bbox_size >= 0:
              w = proposals[:, 2] - proposals[:, 0]
              h = proposals[:, 3] - proposals[:, 1]
              valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
              if not valid_mask.all():
                  proposals = proposals[valid_mask]
                  scores = scores[valid_mask]
                  ids = ids[valid_mask]

          if proposals.numel() == 0:
              return proposals.new_zeros(0, 5)

          dets, _ = batched_nms(proposals, scores, ids, cfg.nms)
          return dets[:cfg.max_per_img]

      def onnx_export(self, x, img_metas):
          """Export the RPN head to ONNX (no test-time augmentation).

          Boxes and scores come from the parent's ``onnx_export`` without NMS
          and are then passed through ``add_dummy_nms_for_onnx`` so NMS runs
          as ONNX ``NonMaxSuppression`` in deployment, across all levels.

          Args:
              x (tuple[Tensor]): Features from the upstream network, each is
                  a 4D-tensor.
              img_metas (list[dict]): Meta info of each image.

          Returns:
              Tensor: dets of shape [N, num_det, 5].
          """
          cls_scores, bbox_preds = self(x)
          assert len(cls_scores) == len(bbox_preds)
          batch_bboxes, batch_scores = super(RPNHead, self).onnx_export(
              cls_scores, bbox_preds, img_metas=img_metas, with_nms=False)
          # Use ONNX::NonMaxSuppression in deployment
          from mmdet.core.export import add_dummy_nms_for_onnx
          cfg = copy.deepcopy(self.test_cfg)
          score_threshold = cfg.nms.get('score_thr', 0.0)
          nms_pre = cfg.get('deploy_nms_pre', -1)
          # Different from the normal forward doing NMS level by level,
          # we do NMS across all levels when exporting ONNX.
          dets, _ = add_dummy_nms_for_onnx(batch_bboxes, batch_scores,
                                           cfg.max_per_img,
                                           cfg.nms.iou_threshold,
                                           score_threshold, nms_pre,
                                           cfg.max_per_img)
          return dets

No Description

Contributors (3)