You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

fovea_head.py 16 kB

2 years ago
(Line-number gutter: lines 1–385)
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import torch
  4. import torch.nn as nn
  5. from mmcv.cnn import ConvModule
  6. from mmcv.ops import DeformConv2d
  7. from mmcv.runner import BaseModule
  8. from mmdet.core import multi_apply
  9. from mmdet.core.utils import filter_scores_and_topk
  10. from ..builder import HEADS
  11. from .anchor_free_head import AnchorFreeHead
# Large constant intended as an "infinity" sentinel. It is not referenced by
# the code visible in this file; other modules may import it.
INF = 1e8
  13. class FeatureAlign(BaseModule):
  14. def __init__(self,
  15. in_channels,
  16. out_channels,
  17. kernel_size=3,
  18. deform_groups=4,
  19. init_cfg=dict(
  20. type='Normal',
  21. layer='Conv2d',
  22. std=0.1,
  23. override=dict(
  24. type='Normal', name='conv_adaption', std=0.01))):
  25. super(FeatureAlign, self).__init__(init_cfg)
  26. offset_channels = kernel_size * kernel_size * 2
  27. self.conv_offset = nn.Conv2d(
  28. 4, deform_groups * offset_channels, 1, bias=False)
  29. self.conv_adaption = DeformConv2d(
  30. in_channels,
  31. out_channels,
  32. kernel_size=kernel_size,
  33. padding=(kernel_size - 1) // 2,
  34. deform_groups=deform_groups)
  35. self.relu = nn.ReLU(inplace=True)
  36. def forward(self, x, shape):
  37. offset = self.conv_offset(shape)
  38. x = self.relu(self.conv_adaption(x, offset))
  39. return x
  40. @HEADS.register_module()
  41. class FoveaHead(AnchorFreeHead):
  42. """FoveaBox: Beyond Anchor-based Object Detector
  43. https://arxiv.org/abs/1904.03797
  44. """
  45. def __init__(self,
  46. num_classes,
  47. in_channels,
  48. base_edge_list=(16, 32, 64, 128, 256),
  49. scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128,
  50. 512)),
  51. sigma=0.4,
  52. with_deform=False,
  53. deform_groups=4,
  54. init_cfg=dict(
  55. type='Normal',
  56. layer='Conv2d',
  57. std=0.01,
  58. override=dict(
  59. type='Normal',
  60. name='conv_cls',
  61. std=0.01,
  62. bias_prob=0.01)),
  63. **kwargs):
  64. self.base_edge_list = base_edge_list
  65. self.scale_ranges = scale_ranges
  66. self.sigma = sigma
  67. self.with_deform = with_deform
  68. self.deform_groups = deform_groups
  69. super().__init__(num_classes, in_channels, init_cfg=init_cfg, **kwargs)
  70. def _init_layers(self):
  71. # box branch
  72. super()._init_reg_convs()
  73. self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
  74. # cls branch
  75. if not self.with_deform:
  76. super()._init_cls_convs()
  77. self.conv_cls = nn.Conv2d(
  78. self.feat_channels, self.cls_out_channels, 3, padding=1)
  79. else:
  80. self.cls_convs = nn.ModuleList()
  81. self.cls_convs.append(
  82. ConvModule(
  83. self.feat_channels, (self.feat_channels * 4),
  84. 3,
  85. stride=1,
  86. padding=1,
  87. conv_cfg=self.conv_cfg,
  88. norm_cfg=self.norm_cfg,
  89. bias=self.norm_cfg is None))
  90. self.cls_convs.append(
  91. ConvModule((self.feat_channels * 4), (self.feat_channels * 4),
  92. 1,
  93. stride=1,
  94. padding=0,
  95. conv_cfg=self.conv_cfg,
  96. norm_cfg=self.norm_cfg,
  97. bias=self.norm_cfg is None))
  98. self.feature_adaption = FeatureAlign(
  99. self.feat_channels,
  100. self.feat_channels,
  101. kernel_size=3,
  102. deform_groups=self.deform_groups)
  103. self.conv_cls = nn.Conv2d(
  104. int(self.feat_channels * 4),
  105. self.cls_out_channels,
  106. 3,
  107. padding=1)
  108. def forward_single(self, x):
  109. cls_feat = x
  110. reg_feat = x
  111. for reg_layer in self.reg_convs:
  112. reg_feat = reg_layer(reg_feat)
  113. bbox_pred = self.conv_reg(reg_feat)
  114. if self.with_deform:
  115. cls_feat = self.feature_adaption(cls_feat, bbox_pred.exp())
  116. for cls_layer in self.cls_convs:
  117. cls_feat = cls_layer(cls_feat)
  118. cls_score = self.conv_cls(cls_feat)
  119. return cls_score, bbox_pred
  120. def loss(self,
  121. cls_scores,
  122. bbox_preds,
  123. gt_bbox_list,
  124. gt_label_list,
  125. img_metas,
  126. gt_bboxes_ignore=None):
  127. assert len(cls_scores) == len(bbox_preds)
  128. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  129. points = self.prior_generator.grid_priors(
  130. featmap_sizes,
  131. dtype=bbox_preds[0].dtype,
  132. device=bbox_preds[0].device)
  133. num_imgs = cls_scores[0].size(0)
  134. flatten_cls_scores = [
  135. cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
  136. for cls_score in cls_scores
  137. ]
  138. flatten_bbox_preds = [
  139. bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
  140. for bbox_pred in bbox_preds
  141. ]
  142. flatten_cls_scores = torch.cat(flatten_cls_scores)
  143. flatten_bbox_preds = torch.cat(flatten_bbox_preds)
  144. flatten_labels, flatten_bbox_targets = self.get_targets(
  145. gt_bbox_list, gt_label_list, featmap_sizes, points)
  146. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  147. pos_inds = ((flatten_labels >= 0)
  148. & (flatten_labels < self.num_classes)).nonzero().view(-1)
  149. num_pos = len(pos_inds)
  150. loss_cls = self.loss_cls(
  151. flatten_cls_scores, flatten_labels, avg_factor=num_pos + num_imgs)
  152. if num_pos > 0:
  153. pos_bbox_preds = flatten_bbox_preds[pos_inds]
  154. pos_bbox_targets = flatten_bbox_targets[pos_inds]
  155. pos_weights = pos_bbox_targets.new_zeros(
  156. pos_bbox_targets.size()) + 1.0
  157. loss_bbox = self.loss_bbox(
  158. pos_bbox_preds,
  159. pos_bbox_targets,
  160. pos_weights,
  161. avg_factor=num_pos)
  162. else:
  163. loss_bbox = torch.tensor(
  164. 0,
  165. dtype=flatten_bbox_preds.dtype,
  166. device=flatten_bbox_preds.device)
  167. return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
  168. def get_targets(self, gt_bbox_list, gt_label_list, featmap_sizes, points):
  169. label_list, bbox_target_list = multi_apply(
  170. self._get_target_single,
  171. gt_bbox_list,
  172. gt_label_list,
  173. featmap_size_list=featmap_sizes,
  174. point_list=points)
  175. flatten_labels = [
  176. torch.cat([
  177. labels_level_img.flatten() for labels_level_img in labels_level
  178. ]) for labels_level in zip(*label_list)
  179. ]
  180. flatten_bbox_targets = [
  181. torch.cat([
  182. bbox_targets_level_img.reshape(-1, 4)
  183. for bbox_targets_level_img in bbox_targets_level
  184. ]) for bbox_targets_level in zip(*bbox_target_list)
  185. ]
  186. flatten_labels = torch.cat(flatten_labels)
  187. flatten_bbox_targets = torch.cat(flatten_bbox_targets)
  188. return flatten_labels, flatten_bbox_targets
  189. def _get_target_single(self,
  190. gt_bboxes_raw,
  191. gt_labels_raw,
  192. featmap_size_list=None,
  193. point_list=None):
  194. gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
  195. (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
  196. label_list = []
  197. bbox_target_list = []
  198. # for each pyramid, find the cls and box target
  199. for base_len, (lower_bound, upper_bound), stride, featmap_size, \
  200. points in zip(self.base_edge_list, self.scale_ranges,
  201. self.strides, featmap_size_list, point_list):
  202. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  203. points = points.view(*featmap_size, 2)
  204. x, y = points[..., 0], points[..., 1]
  205. labels = gt_labels_raw.new_zeros(featmap_size) + self.num_classes
  206. bbox_targets = gt_bboxes_raw.new(featmap_size[0], featmap_size[1],
  207. 4) + 1
  208. # scale assignment
  209. hit_indices = ((gt_areas >= lower_bound) &
  210. (gt_areas <= upper_bound)).nonzero().flatten()
  211. if len(hit_indices) == 0:
  212. label_list.append(labels)
  213. bbox_target_list.append(torch.log(bbox_targets))
  214. continue
  215. _, hit_index_order = torch.sort(-gt_areas[hit_indices])
  216. hit_indices = hit_indices[hit_index_order]
  217. gt_bboxes = gt_bboxes_raw[hit_indices, :] / stride
  218. gt_labels = gt_labels_raw[hit_indices]
  219. half_w = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0])
  220. half_h = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
  221. # valid fovea area: left, right, top, down
  222. pos_left = torch.ceil(
  223. gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long(). \
  224. clamp(0, featmap_size[1] - 1)
  225. pos_right = torch.floor(
  226. gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long(). \
  227. clamp(0, featmap_size[1] - 1)
  228. pos_top = torch.ceil(
  229. gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long(). \
  230. clamp(0, featmap_size[0] - 1)
  231. pos_down = torch.floor(
  232. gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long(). \
  233. clamp(0, featmap_size[0] - 1)
  234. for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \
  235. zip(pos_left, pos_top, pos_right, pos_down, gt_labels,
  236. gt_bboxes_raw[hit_indices, :]):
  237. labels[py1:py2 + 1, px1:px2 + 1] = label
  238. bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \
  239. (x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len
  240. bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \
  241. (y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len
  242. bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \
  243. (gt_x2 - x[py1:py2 + 1, px1:px2 + 1]) / base_len
  244. bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \
  245. (gt_y2 - y[py1:py2 + 1, px1:px2 + 1]) / base_len
  246. bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.)
  247. label_list.append(labels)
  248. bbox_target_list.append(torch.log(bbox_targets))
  249. return label_list, bbox_target_list
  250. # Same as base_dense_head/_get_bboxes_single except self._bbox_decode
  251. def _get_bboxes_single(self,
  252. cls_score_list,
  253. bbox_pred_list,
  254. score_factor_list,
  255. mlvl_priors,
  256. img_meta,
  257. cfg,
  258. rescale=False,
  259. with_nms=True,
  260. **kwargs):
  261. """Transform outputs of a single image into bbox predictions.
  262. Args:
  263. cls_score_list (list[Tensor]): Box scores from all scale
  264. levels of a single image, each item has shape
  265. (num_priors * num_classes, H, W).
  266. bbox_pred_list (list[Tensor]): Box energies / deltas from
  267. all scale levels of a single image, each item has shape
  268. (num_priors * 4, H, W).
  269. score_factor_list (list[Tensor]): Score factor from all scale
  270. levels of a single image. Fovea head does not need this value.
  271. mlvl_priors (list[Tensor]): Each element in the list is
  272. the priors of a single level in feature pyramid, has shape
  273. (num_priors, 2).
  274. img_meta (dict): Image meta info.
  275. cfg (mmcv.Config): Test / postprocessing configuration,
  276. if None, test_cfg would be used.
  277. rescale (bool): If True, return boxes in original image space.
  278. Default: False.
  279. with_nms (bool): If True, do nms before return boxes.
  280. Default: True.
  281. Returns:
  282. tuple[Tensor]: Results of detected bboxes and labels. If with_nms
  283. is False and mlvl_score_factor is None, return mlvl_bboxes and
  284. mlvl_scores, else return mlvl_bboxes, mlvl_scores and
  285. mlvl_score_factor. Usually with_nms is False is used for aug
  286. test. If with_nms is True, then return the following format
  287. - det_bboxes (Tensor): Predicted bboxes with shape \
  288. [num_bboxes, 5], where the first 4 columns are bounding \
  289. box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
  290. column are scores between 0 and 1.
  291. - det_labels (Tensor): Predicted labels of the corresponding \
  292. box with shape [num_bboxes].
  293. """
  294. cfg = self.test_cfg if cfg is None else cfg
  295. assert len(cls_score_list) == len(bbox_pred_list)
  296. img_shape = img_meta['img_shape']
  297. nms_pre = cfg.get('nms_pre', -1)
  298. mlvl_bboxes = []
  299. mlvl_scores = []
  300. mlvl_labels = []
  301. for level_idx, (cls_score, bbox_pred, stride, base_len, priors) in \
  302. enumerate(zip(cls_score_list, bbox_pred_list, self.strides,
  303. self.base_edge_list, mlvl_priors)):
  304. assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
  305. bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
  306. scores = cls_score.permute(1, 2, 0).reshape(
  307. -1, self.cls_out_channels).sigmoid()
  308. # After https://github.com/open-mmlab/mmdetection/pull/6268/,
  309. # this operation keeps fewer bboxes under the same `nms_pre`.
  310. # There is no difference in performance for most models. If you
  311. # find a slight drop in performance, you can set a larger
  312. # `nms_pre` than before.
  313. results = filter_scores_and_topk(
  314. scores, cfg.score_thr, nms_pre,
  315. dict(bbox_pred=bbox_pred, priors=priors))
  316. scores, labels, _, filtered_results = results
  317. bbox_pred = filtered_results['bbox_pred']
  318. priors = filtered_results['priors']
  319. bboxes = self._bbox_decode(priors, bbox_pred, base_len, img_shape)
  320. mlvl_bboxes.append(bboxes)
  321. mlvl_scores.append(scores)
  322. mlvl_labels.append(labels)
  323. return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes,
  324. img_meta['scale_factor'], cfg, rescale,
  325. with_nms)
  326. def _bbox_decode(self, priors, bbox_pred, base_len, max_shape):
  327. bbox_pred = bbox_pred.exp()
  328. y = priors[:, 1]
  329. x = priors[:, 0]
  330. x1 = (x - base_len * bbox_pred[:, 0]). \
  331. clamp(min=0, max=max_shape[1] - 1)
  332. y1 = (y - base_len * bbox_pred[:, 1]). \
  333. clamp(min=0, max=max_shape[0] - 1)
  334. x2 = (x + base_len * bbox_pred[:, 2]). \
  335. clamp(min=0, max=max_shape[1] - 1)
  336. y2 = (y + base_len * bbox_pred[:, 3]). \
  337. clamp(min=0, max=max_shape[0] - 1)
  338. decoded_bboxes = torch.stack([x1, y1, x2, y2], -1)
  339. return decoded_bboxes
  340. def _get_points_single(self, *args, **kwargs):
  341. """Get points according to feature map size.
  342. This function will be deprecated soon.
  343. """
  344. warnings.warn(
  345. '`_get_points_single` in `FoveaHead` will be '
  346. 'deprecated soon, we support a multi level point generator now'
  347. 'you can get points of a single level feature map '
  348. 'with `self.prior_generator.single_level_grid_priors` ')
  349. y, x = super()._get_points_single(*args, **kwargs)
  350. return y + 0.5, x + 0.5

No Description

Contributors (3)