
sabl_retina_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import force_fp32

from mmdet.core import (build_assigner, build_bbox_coder,
                        build_prior_generator, build_sampler, images_to_levels,
                        multi_apply, unmap)
from mmdet.core.utils import filter_scores_and_topk
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
from .guided_anchor_head import GuidedAnchorHead


@HEADS.register_module()
class SABLRetinaHead(BaseDenseHead, BBoxTestMixin):
    """Side-Aware Boundary Localization (SABL) for RetinaNet.

    The anchor generation, assigning and sampling in SABLRetinaHead
    are the same as GuidedAnchorHead for guided anchoring.

    Please refer to https://arxiv.org/abs/1912.04260 for more details.

    Args:
        num_classes (int): Number of classes.
        in_channels (int): Number of channels in the input feature map.
        stacked_convs (int): Number of Convs for classification \
            and regression branches. Defaults to 4.
        feat_channels (int): Number of hidden channels. \
            Defaults to 256.
        approx_anchor_generator (dict): Config dict for approx generator.
        square_anchor_generator (dict): Config dict for square generator.
        conv_cfg (dict): Config dict for ConvModule. Defaults to None.
        norm_cfg (dict): Config dict for Norm Layer. Defaults to None.
        bbox_coder (dict): Config dict for bbox coder.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Defaults to False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
        train_cfg (dict): Training config of SABLRetinaHead.
        test_cfg (dict): Testing config of SABLRetinaHead.
        loss_cls (dict): Config of classification loss.
        loss_bbox_cls (dict): Config of classification loss for bbox branch.
        loss_bbox_reg (dict): Config of regression loss for bbox branch.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """
    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 feat_channels=256,
                 approx_anchor_generator=dict(
                     type='AnchorGenerator',
                     octave_base_scale=4,
                     scales_per_octave=3,
                     ratios=[0.5, 1.0, 2.0],
                     strides=[8, 16, 32, 64, 128]),
                 square_anchor_generator=dict(
                     type='AnchorGenerator',
                     ratios=[1.0],
                     scales=[4],
                     strides=[8, 16, 32, 64, 128]),
                 conv_cfg=None,
                 norm_cfg=None,
                 bbox_coder=dict(
                     type='BucketingBBoxCoder',
                     num_buckets=14,
                     scale_factor=3.0),
                 reg_decoded_bbox=False,
                 train_cfg=None,
                 test_cfg=None,
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.5),
                 loss_bbox_reg=dict(
                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5),
                 init_cfg=dict(
                     type='Normal',
                     layer='Conv2d',
                     std=0.01,
                     override=dict(
                         type='Normal',
                         name='retina_cls',
                         std=0.01,
                         bias_prob=0.01))):
        super(SABLRetinaHead, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.feat_channels = feat_channels
        self.num_buckets = bbox_coder['num_buckets']
        self.side_num = int(np.ceil(self.num_buckets / 2))
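        # e.g. with the default BucketingBBoxCoder (num_buckets=14), each box
        # side is covered by side_num = ceil(14 / 2) = 7 buckets, so the
        # bucketing branches predict side_num * 4 = 28 channels per anchor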
        assert (approx_anchor_generator['octave_base_scale'] ==
                square_anchor_generator['scales'][0])
        assert (approx_anchor_generator['strides'] ==
                square_anchor_generator['strides'])
        self.approx_anchor_generator = build_prior_generator(
            approx_anchor_generator)
        self.square_anchor_generator = build_prior_generator(
            square_anchor_generator)
        self.approxs_per_octave = (
            self.approx_anchor_generator.num_base_priors[0])

        # one anchor per location
        self.num_base_priors = self.square_anchor_generator.num_base_priors[0]

        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.reg_decoded_bbox = reg_decoded_bbox

        self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
        self.sampling = loss_cls['type'] not in [
            'FocalLoss', 'GHMC', 'QualityFocalLoss'
        ]
        if self.use_sigmoid_cls:
            self.cls_out_channels = num_classes
        else:
            self.cls_out_channels = num_classes + 1

        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox_cls = build_loss(loss_bbox_cls)
        self.loss_bbox_reg = build_loss(loss_bbox_reg)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if self.train_cfg:
            self.assigner = build_assigner(self.train_cfg.assigner)
            # use PseudoSampler when sampling is False
            if self.sampling and hasattr(self.train_cfg, 'sampler'):
                sampler_cfg = self.train_cfg.sampler
            else:
                sampler_cfg = dict(type='PseudoSampler')
            self.sampler = build_sampler(sampler_cfg, context=self)

        self.fp16_enabled = False
        self._init_layers()

    @property
    def num_anchors(self):
        warnings.warn('DeprecationWarning: `num_anchors` is deprecated, '
                      'please use "num_base_priors" instead')
        return self.square_anchor_generator.num_base_priors[0]

    def _init_layers(self):
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        self.retina_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        self.retina_bbox_reg = nn.Conv2d(
            self.feat_channels, self.side_num * 4, 3, padding=1)
        self.retina_bbox_cls = nn.Conv2d(
            self.feat_channels, self.side_num * 4, 3, padding=1)

    def forward_single(self, x):
        cls_feat = x
        reg_feat = x
        for cls_conv in self.cls_convs:
            cls_feat = cls_conv(cls_feat)
        for reg_conv in self.reg_convs:
            reg_feat = reg_conv(reg_feat)
        cls_score = self.retina_cls(cls_feat)
        bbox_cls_pred = self.retina_bbox_cls(reg_feat)
        bbox_reg_pred = self.retina_bbox_reg(reg_feat)
        bbox_pred = (bbox_cls_pred, bbox_reg_pred)
        return cls_score, bbox_pred

    def forward(self, feats):
        return multi_apply(self.forward_single, feats)

    def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
        """Get squares according to feature map sizes and guided anchors.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            img_metas (list[dict]): Image meta info.
            device (torch.device | str): device for returned tensors

        Returns:
            tuple: square approxs of each image
        """
        num_imgs = len(img_metas)
        # since feature map sizes of all images are the same, we only compute
        # squares for one time
        multi_level_squares = self.square_anchor_generator.grid_priors(
            featmap_sizes, device=device)
        squares_list = [multi_level_squares for _ in range(num_imgs)]
        return squares_list

    def get_target(self,
                   approx_list,
                   inside_flag_list,
                   square_list,
                   gt_bboxes_list,
                   img_metas,
                   gt_bboxes_ignore_list=None,
                   gt_labels_list=None,
                   label_channels=None,
                   sampling=True,
                   unmap_outputs=True):
        """Compute bucketing targets.

        Args:
            approx_list (list[list]): Multi level approxs of each image.
            inside_flag_list (list[list]): Multi level inside flags of each
                image.
            square_list (list[list]): Multi level squares of each image.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_bboxes_ignore_list (list[Tensor]): ignore list of gt bboxes.
            gt_labels_list (list[Tensor]): Gt labels of each image.
            label_channels (int): Channel of label.
            sampling (bool): Sample Anchors or not.
            unmap_outputs (bool): unmap outputs or not.

        Returns:
            tuple: Returns a tuple containing learning targets.

                - labels_list (list[Tensor]): Labels of each level.
                - label_weights_list (list[Tensor]): Label weights of each \
                    level.
                - bbox_cls_targets_list (list[Tensor]): BBox cls targets of \
                    each level.
                - bbox_cls_weights_list (list[Tensor]): BBox cls weights of \
                    each level.
                - bbox_reg_targets_list (list[Tensor]): BBox reg targets of \
                    each level.
                - bbox_reg_weights_list (list[Tensor]): BBox reg weights of \
                    each level.
                - num_total_pos (int): Number of positive samples in all \
                    images.
                - num_total_neg (int): Number of negative samples in all \
                    images.
        """
        num_imgs = len(img_metas)
        assert len(approx_list) == len(inside_flag_list) == len(
            square_list) == num_imgs

        # anchor number of multi levels
        num_level_squares = [squares.size(0) for squares in square_list[0]]

        # concat all level anchors and flags to a single tensor
        inside_flag_flat_list = []
        approx_flat_list = []
        square_flat_list = []
        for i in range(num_imgs):
            assert len(square_list[i]) == len(inside_flag_list[i])
            inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
            approx_flat_list.append(torch.cat(approx_list[i]))
            square_flat_list.append(torch.cat(square_list[i]))

        # compute targets for each image
        if gt_bboxes_ignore_list is None:
            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
        if gt_labels_list is None:
            gt_labels_list = [None for _ in range(num_imgs)]
        (all_labels, all_label_weights, all_bbox_cls_targets,
         all_bbox_cls_weights, all_bbox_reg_targets, all_bbox_reg_weights,
         pos_inds_list, neg_inds_list) = multi_apply(
             self._get_target_single,
             approx_flat_list,
             inside_flag_flat_list,
             square_flat_list,
             gt_bboxes_list,
             gt_bboxes_ignore_list,
             gt_labels_list,
             img_metas,
             label_channels=label_channels,
             sampling=sampling,
             unmap_outputs=unmap_outputs)
        # no valid anchors
        if any([labels is None for labels in all_labels]):
            return None
        # sampled anchors of all images
        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
        # split targets to a list w.r.t. multiple levels
        labels_list = images_to_levels(all_labels, num_level_squares)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_squares)
        bbox_cls_targets_list = images_to_levels(all_bbox_cls_targets,
                                                 num_level_squares)
        bbox_cls_weights_list = images_to_levels(all_bbox_cls_weights,
                                                 num_level_squares)
        bbox_reg_targets_list = images_to_levels(all_bbox_reg_targets,
                                                 num_level_squares)
        bbox_reg_weights_list = images_to_levels(all_bbox_reg_weights,
                                                 num_level_squares)
        return (labels_list, label_weights_list, bbox_cls_targets_list,
                bbox_cls_weights_list, bbox_reg_targets_list,
                bbox_reg_weights_list, num_total_pos, num_total_neg)

    def _get_target_single(self,
                           flat_approxs,
                           inside_flags,
                           flat_squares,
                           gt_bboxes,
                           gt_bboxes_ignore,
                           gt_labels,
                           img_meta,
                           label_channels=None,
                           sampling=True,
                           unmap_outputs=True):
        """Compute regression and classification targets for anchors in a
        single image.

        Args:
            flat_approxs (Tensor): flat approxs of a single image,
                shape (n, 4)
            inside_flags (Tensor): inside flags of a single image,
                shape (n, ).
            flat_squares (Tensor): flat squares of a single image,
                shape (approxs_per_octave * n, 4)
            gt_bboxes (Tensor): Ground truth bboxes of a single image, \
                shape (num_gts, 4).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            img_meta (dict): Meta info of the image.
            label_channels (int): Channel of label.
            sampling (bool): Sample Anchors or not.
            unmap_outputs (bool): unmap outputs or not.

        Returns:
            tuple:

                - labels_list (Tensor): Labels in a single image
                - label_weights (Tensor): Label weights in a single image
                - bbox_cls_targets (Tensor): BBox cls targets in a single image
                - bbox_cls_weights (Tensor): BBox cls weights in a single image
                - bbox_reg_targets (Tensor): BBox reg targets in a single image
                - bbox_reg_weights (Tensor): BBox reg weights in a single image
                - num_total_pos (int): Number of positive samples \
                    in a single image
                - num_total_neg (int): Number of negative samples \
                    in a single image
        """
        if not inside_flags.any():
            return (None, ) * 8
        # assign gt and sample anchors
        expand_inside_flags = inside_flags[:, None].expand(
            -1, self.approxs_per_octave).reshape(-1)
        approxs = flat_approxs[expand_inside_flags, :]
        squares = flat_squares[inside_flags, :]

        assign_result = self.assigner.assign(approxs, squares,
                                             self.approxs_per_octave,
                                             gt_bboxes, gt_bboxes_ignore)
        sampling_result = self.sampler.sample(assign_result, squares,
                                              gt_bboxes)

        num_valid_squares = squares.shape[0]
        bbox_cls_targets = squares.new_zeros(
            (num_valid_squares, self.side_num * 4))
        bbox_cls_weights = squares.new_zeros(
            (num_valid_squares, self.side_num * 4))
        bbox_reg_targets = squares.new_zeros(
            (num_valid_squares, self.side_num * 4))
        bbox_reg_weights = squares.new_zeros(
            (num_valid_squares, self.side_num * 4))
        labels = squares.new_full((num_valid_squares, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = squares.new_zeros(num_valid_squares, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            (pos_bbox_reg_targets, pos_bbox_reg_weights, pos_bbox_cls_targets,
             pos_bbox_cls_weights) = self.bbox_coder.encode(
                 sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
            bbox_cls_targets[pos_inds, :] = pos_bbox_cls_targets
            bbox_reg_targets[pos_inds, :] = pos_bbox_reg_targets
            bbox_cls_weights[pos_inds, :] = pos_bbox_cls_weights
            bbox_reg_weights[pos_inds, :] = pos_bbox_reg_weights
            if gt_labels is None:
                # Only rpn gives gt_labels as None
                # Foreground is the first class
                labels[pos_inds] = 0
            else:
                labels[pos_inds] = gt_labels[
                    sampling_result.pos_assigned_gt_inds]
            if self.train_cfg.pos_weight <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg.pos_weight
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_squares.size(0)
            labels = unmap(
                labels, num_total_anchors, inside_flags, fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_cls_targets = unmap(bbox_cls_targets, num_total_anchors,
                                     inside_flags)
            bbox_cls_weights = unmap(bbox_cls_weights, num_total_anchors,
                                     inside_flags)
            bbox_reg_targets = unmap(bbox_reg_targets, num_total_anchors,
                                     inside_flags)
            bbox_reg_weights = unmap(bbox_reg_weights, num_total_anchors,
                                     inside_flags)
        return (labels, label_weights, bbox_cls_targets, bbox_cls_weights,
                bbox_reg_targets, bbox_reg_weights, pos_inds, neg_inds)
    def loss_single(self, cls_score, bbox_pred, labels, label_weights,
                    bbox_cls_targets, bbox_cls_weights, bbox_reg_targets,
                    bbox_reg_weights, num_total_samples):
        # classification loss
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=num_total_samples)
        # regression loss
        bbox_cls_targets = bbox_cls_targets.reshape(-1, self.side_num * 4)
        bbox_cls_weights = bbox_cls_weights.reshape(-1, self.side_num * 4)
        bbox_reg_targets = bbox_reg_targets.reshape(-1, self.side_num * 4)
        bbox_reg_weights = bbox_reg_weights.reshape(-1, self.side_num * 4)
        (bbox_cls_pred, bbox_reg_pred) = bbox_pred
        bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(
            -1, self.side_num * 4)
        bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape(
            -1, self.side_num * 4)
        loss_bbox_cls = self.loss_bbox_cls(
            bbox_cls_pred,
            bbox_cls_targets.long(),
            bbox_cls_weights,
            avg_factor=num_total_samples * 4 * self.side_num)
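        # Note on the averaging factors: every sample contributes
        # 4 * side_num bucket logits to the bucket-classification loss above,
        # while the offset loss below only supervises the coder's offset_topk
        # buckets per side, hence 4 * offset_topk entries per sample.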
        loss_bbox_reg = self.loss_bbox_reg(
            bbox_reg_pred,
            bbox_reg_targets,
            bbox_reg_weights,
            avg_factor=num_total_samples * 4 * self.bbox_coder.offset_topk)
        return loss_cls, loss_bbox_cls, loss_bbox_reg

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.approx_anchor_generator.num_levels

        device = cls_scores[0].device

        # get sampled approxes
        approxs_list, inside_flag_list = GuidedAnchorHead.get_sampled_approxs(
            self, featmap_sizes, img_metas, device=device)

        square_list = self.get_anchors(featmap_sizes, img_metas, device=device)

        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        cls_reg_targets = self.get_target(
            approxs_list,
            inside_flag_list,
            square_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            sampling=self.sampling)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_cls_targets_list,
         bbox_cls_weights_list, bbox_reg_targets_list, bbox_reg_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)
        losses_cls, losses_bbox_cls, losses_bbox_reg = multi_apply(
            self.loss_single,
            cls_scores,
            bbox_preds,
            labels_list,
            label_weights_list,
            bbox_cls_targets_list,
            bbox_cls_weights_list,
            bbox_reg_targets_list,
            bbox_reg_weights_list,
            num_total_samples=num_total_samples)
        return dict(
            loss_cls=losses_cls,
            loss_bbox_cls=losses_bbox_cls,
            loss_bbox_reg=losses_bbox_reg)

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=False):
        assert len(cls_scores) == len(bbox_preds)
        num_levels = len(cls_scores)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]

        device = cls_scores[0].device
        mlvl_anchors = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        result_list = []
        for img_id in range(len(img_metas)):
            cls_score_list = [
                cls_scores[i][img_id].detach() for i in range(num_levels)
            ]
            bbox_cls_pred_list = [
                bbox_preds[i][0][img_id].detach() for i in range(num_levels)
            ]
            bbox_reg_pred_list = [
                bbox_preds[i][1][img_id].detach() for i in range(num_levels)
            ]
            img_shape = img_metas[img_id]['img_shape']
            scale_factor = img_metas[img_id]['scale_factor']
            proposals = self._get_bboxes_single(
                cls_score_list, bbox_cls_pred_list, bbox_reg_pred_list,
                mlvl_anchors[img_id], img_shape, scale_factor, cfg, rescale)
            result_list.append(proposals)
        return result_list

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_cls_preds,
                           bbox_reg_preds,
                           mlvl_anchors,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        cfg = self.test_cfg if cfg is None else cfg
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bboxes = []
        mlvl_scores = []
        mlvl_confids = []
        mlvl_labels = []
        assert len(cls_scores) == len(bbox_cls_preds) == len(
            bbox_reg_preds) == len(mlvl_anchors)
        for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip(
                cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors):
            assert cls_score.size()[-2:] == bbox_cls_pred.size(
            )[-2:] == bbox_reg_pred.size()[-2:]
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                scores = cls_score.softmax(-1)[:, :-1]
            bbox_cls_pred = bbox_cls_pred.permute(1, 2, 0).reshape(
                -1, self.side_num * 4)
            bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape(
                -1, self.side_num * 4)

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            results = filter_scores_and_topk(
                scores, cfg.score_thr, nms_pre,
                dict(
                    anchors=anchors,
                    bbox_cls_pred=bbox_cls_pred,
                    bbox_reg_pred=bbox_reg_pred))
            scores, labels, _, filtered_results = results

            anchors = filtered_results['anchors']
            bbox_cls_pred = filtered_results['bbox_cls_pred']
            bbox_reg_pred = filtered_results['bbox_reg_pred']

            bbox_preds = [
                bbox_cls_pred.contiguous(),
                bbox_reg_pred.contiguous()
            ]
            bboxes, confids = self.bbox_coder.decode(
                anchors.contiguous(), bbox_preds, max_shape=img_shape)

            mlvl_bboxes.append(bboxes)
            mlvl_scores.append(scores)
            mlvl_confids.append(confids)
            mlvl_labels.append(labels)
        return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes,
                                       scale_factor, cfg, rescale, True,
                                       mlvl_confids)
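

# Usage sketch (illustrative, assumes an installed mmdet 2.x where this module
# lives under mmdet.models.dense_heads; the batch size and feature shapes
# below are arbitrary assumptions, not requirements of the class):
#
# >>> import torch
# >>> from mmdet.models.dense_heads import SABLRetinaHead
# >>> head = SABLRetinaHead(num_classes=80, in_channels=256)
# >>> feats = [torch.rand(2, 256, 64 // 2**i, 64 // 2**i) for i in range(5)]
# >>> cls_scores, bbox_preds = head(feats)
# >>> cls_scores[0].shape  # classification logits for the first level
# torch.Size([2, 80, 64, 64])
# >>> bbox_preds[0][0].shape  # bucket cls half of the (cls, reg) pair
# torch.Size([2, 28, 64, 64])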
