You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ga_rpn_head.py 7.1 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import copy
  3. import warnings
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from mmcv import ConfigDict
  8. from mmcv.ops import nms
  9. from ..builder import HEADS
  10. from .guided_anchor_head import GuidedAnchorHead
  11. @HEADS.register_module()
  12. class GARPNHead(GuidedAnchorHead):
  13. """Guided-Anchor-based RPN head."""
  14. def __init__(self,
  15. in_channels,
  16. init_cfg=dict(
  17. type='Normal',
  18. layer='Conv2d',
  19. std=0.01,
  20. override=dict(
  21. type='Normal',
  22. name='conv_loc',
  23. std=0.01,
  24. bias_prob=0.01)),
  25. **kwargs):
  26. super(GARPNHead, self).__init__(
  27. 1, in_channels, init_cfg=init_cfg, **kwargs)
  28. def _init_layers(self):
  29. """Initialize layers of the head."""
  30. self.rpn_conv = nn.Conv2d(
  31. self.in_channels, self.feat_channels, 3, padding=1)
  32. super(GARPNHead, self)._init_layers()
  33. def forward_single(self, x):
  34. """Forward feature of a single scale level."""
  35. x = self.rpn_conv(x)
  36. x = F.relu(x, inplace=True)
  37. (cls_score, bbox_pred, shape_pred,
  38. loc_pred) = super(GARPNHead, self).forward_single(x)
  39. return cls_score, bbox_pred, shape_pred, loc_pred
  40. def loss(self,
  41. cls_scores,
  42. bbox_preds,
  43. shape_preds,
  44. loc_preds,
  45. gt_bboxes,
  46. img_metas,
  47. gt_bboxes_ignore=None):
  48. losses = super(GARPNHead, self).loss(
  49. cls_scores,
  50. bbox_preds,
  51. shape_preds,
  52. loc_preds,
  53. gt_bboxes,
  54. None,
  55. img_metas,
  56. gt_bboxes_ignore=gt_bboxes_ignore)
  57. return dict(
  58. loss_rpn_cls=losses['loss_cls'],
  59. loss_rpn_bbox=losses['loss_bbox'],
  60. loss_anchor_shape=losses['loss_shape'],
  61. loss_anchor_loc=losses['loss_loc'])
  62. def _get_bboxes_single(self,
  63. cls_scores,
  64. bbox_preds,
  65. mlvl_anchors,
  66. mlvl_masks,
  67. img_shape,
  68. scale_factor,
  69. cfg,
  70. rescale=False):
  71. cfg = self.test_cfg if cfg is None else cfg
  72. cfg = copy.deepcopy(cfg)
  73. # deprecate arguments warning
  74. if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
  75. warnings.warn(
  76. 'In rpn_proposal or test_cfg, '
  77. 'nms_thr has been moved to a dict named nms as '
  78. 'iou_threshold, max_num has been renamed as max_per_img, '
  79. 'name of original arguments and the way to specify '
  80. 'iou_threshold of NMS will be deprecated.')
  81. if 'nms' not in cfg:
  82. cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
  83. if 'max_num' in cfg:
  84. if 'max_per_img' in cfg:
  85. assert cfg.max_num == cfg.max_per_img, f'You ' \
  86. f'set max_num and max_per_img at the same time, ' \
  87. f'but get {cfg.max_num} ' \
  88. f'and {cfg.max_per_img} respectively' \
  89. 'Please delete max_num which will be deprecated.'
  90. else:
  91. cfg.max_per_img = cfg.max_num
  92. if 'nms_thr' in cfg:
  93. assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
  94. f'iou_threshold in nms and ' \
  95. f'nms_thr at the same time, but get ' \
  96. f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
  97. f' respectively. Please delete the ' \
  98. f'nms_thr which will be deprecated.'
  99. assert cfg.nms.get('type', 'nms') == 'nms', 'GARPNHead only support ' \
  100. 'naive nms.'
  101. mlvl_proposals = []
  102. for idx in range(len(cls_scores)):
  103. rpn_cls_score = cls_scores[idx]
  104. rpn_bbox_pred = bbox_preds[idx]
  105. anchors = mlvl_anchors[idx]
  106. mask = mlvl_masks[idx]
  107. assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
  108. # if no location is kept, end.
  109. if mask.sum() == 0:
  110. continue
  111. rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
  112. if self.use_sigmoid_cls:
  113. rpn_cls_score = rpn_cls_score.reshape(-1)
  114. scores = rpn_cls_score.sigmoid()
  115. else:
  116. rpn_cls_score = rpn_cls_score.reshape(-1, 2)
  117. # remind that we set FG labels to [0, num_class-1]
  118. # since mmdet v2.0
  119. # BG cat_id: num_class
  120. scores = rpn_cls_score.softmax(dim=1)[:, :-1]
  121. # filter scores, bbox_pred w.r.t. mask.
  122. # anchors are filtered in get_anchors() beforehand.
  123. scores = scores[mask]
  124. rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1,
  125. 4)[mask, :]
  126. if scores.dim() == 0:
  127. rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
  128. anchors = anchors.unsqueeze(0)
  129. scores = scores.unsqueeze(0)
  130. # filter anchors, bbox_pred, scores w.r.t. scores
  131. if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
  132. _, topk_inds = scores.topk(cfg.nms_pre)
  133. rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
  134. anchors = anchors[topk_inds, :]
  135. scores = scores[topk_inds]
  136. # get proposals w.r.t. anchors and rpn_bbox_pred
  137. proposals = self.bbox_coder.decode(
  138. anchors, rpn_bbox_pred, max_shape=img_shape)
  139. # filter out too small bboxes
  140. if cfg.min_bbox_size >= 0:
  141. w = proposals[:, 2] - proposals[:, 0]
  142. h = proposals[:, 3] - proposals[:, 1]
  143. valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
  144. if not valid_mask.all():
  145. proposals = proposals[valid_mask]
  146. scores = scores[valid_mask]
  147. # NMS in current level
  148. proposals, _ = nms(proposals, scores, cfg.nms.iou_threshold)
  149. proposals = proposals[:cfg.nms_post, :]
  150. mlvl_proposals.append(proposals)
  151. proposals = torch.cat(mlvl_proposals, 0)
  152. if cfg.get('nms_across_levels', False):
  153. # NMS across multi levels
  154. proposals, _ = nms(proposals[:, :4], proposals[:, -1],
  155. cfg.nms.iou_threshold)
  156. proposals = proposals[:cfg.max_per_img, :]
  157. else:
  158. scores = proposals[:, 4]
  159. num = min(cfg.max_per_img, proposals.shape[0])
  160. _, topk_inds = scores.topk(num)
  161. proposals = proposals[topk_inds, :]
  162. return proposals

No Description

Contributors (1)