You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

paa_head.py 34 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. from mmcv.runner import force_fp32
  5. from mmdet.core import multi_apply, multiclass_nms
  6. from mmdet.core.bbox.iou_calculators import bbox_overlaps
  7. from mmdet.models import HEADS
  8. from mmdet.models.dense_heads import ATSSHead
  9. EPS = 1e-12
  10. try:
  11. import sklearn.mixture as skm
  12. except ImportError:
  13. skm = None
  14. def levels_to_images(mlvl_tensor):
  15. """Concat multi-level feature maps by image.
  16. [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
  17. Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
  18. (N, H*W , C), then split the element to N elements with shape (H*W, C), and
  19. concat elements in same image of all level along first dimension.
  20. Args:
  21. mlvl_tensor (list[torch.Tensor]): list of Tensor which collect from
  22. corresponding level. Each element is of shape (N, C, H, W)
  23. Returns:
  24. list[torch.Tensor]: A list that contains N tensors and each tensor is
  25. of shape (num_elements, C)
  26. """
  27. batch_size = mlvl_tensor[0].size(0)
  28. batch_list = [[] for _ in range(batch_size)]
  29. channels = mlvl_tensor[0].size(1)
  30. for t in mlvl_tensor:
  31. t = t.permute(0, 2, 3, 1)
  32. t = t.view(batch_size, -1, channels).contiguous()
  33. for img in range(batch_size):
  34. batch_list[img].append(t[img])
  35. return [torch.cat(item, 0) for item in batch_list]
  36. @HEADS.register_module()
  37. class PAAHead(ATSSHead):
  38. """Head of PAAAssignment: Probabilistic Anchor Assignment with IoU
  39. Prediction for Object Detection.
  40. Code is modified from the `official github repo
  41. <https://github.com/kkhoot/PAA/blob/master/paa_core
  42. /modeling/rpn/paa/loss.py>`_.
  43. More details can be found in the `paper
  44. <https://arxiv.org/abs/2007.08103>`_ .
  45. Args:
  46. topk (int): Select topk samples with smallest loss in
  47. each level.
  48. score_voting (bool): Whether to use score voting in post-process.
  49. covariance_type : String describing the type of covariance parameters
  50. to be used in :class:`sklearn.mixture.GaussianMixture`.
  51. It must be one of:
  52. - 'full': each component has its own general covariance matrix
  53. - 'tied': all components share the same general covariance matrix
  54. - 'diag': each component has its own diagonal covariance matrix
  55. - 'spherical': each component has its own single variance
  56. Default: 'diag'. From 'full' to 'spherical', the gmm fitting
  57. process is faster yet the performance could be influenced. For most
  58. cases, 'diag' should be a good choice.
  59. """
  60. def __init__(self,
  61. *args,
  62. topk=9,
  63. score_voting=True,
  64. covariance_type='diag',
  65. **kwargs):
  66. # topk used in paa reassign process
  67. self.topk = topk
  68. self.with_score_voting = score_voting
  69. self.covariance_type = covariance_type
  70. super(PAAHead, self).__init__(*args, **kwargs)
  71. @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
  72. def loss(self,
  73. cls_scores,
  74. bbox_preds,
  75. iou_preds,
  76. gt_bboxes,
  77. gt_labels,
  78. img_metas,
  79. gt_bboxes_ignore=None):
  80. """Compute losses of the head.
  81. Args:
  82. cls_scores (list[Tensor]): Box scores for each scale level
  83. Has shape (N, num_anchors * num_classes, H, W)
  84. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  85. level with shape (N, num_anchors * 4, H, W)
  86. iou_preds (list[Tensor]): iou_preds for each scale
  87. level with shape (N, num_anchors * 1, H, W)
  88. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
  89. shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
  90. gt_labels (list[Tensor]): class indices corresponding to each box
  91. img_metas (list[dict]): Meta information of each image, e.g.,
  92. image size, scaling factor, etc.
  93. gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
  94. boxes can be ignored when are computing the loss.
  95. Returns:
  96. dict[str, Tensor]: A dictionary of loss gmm_assignment.
  97. """
  98. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  99. assert len(featmap_sizes) == self.prior_generator.num_levels
  100. device = cls_scores[0].device
  101. anchor_list, valid_flag_list = self.get_anchors(
  102. featmap_sizes, img_metas, device=device)
  103. label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
  104. cls_reg_targets = self.get_targets(
  105. anchor_list,
  106. valid_flag_list,
  107. gt_bboxes,
  108. img_metas,
  109. gt_bboxes_ignore_list=gt_bboxes_ignore,
  110. gt_labels_list=gt_labels,
  111. label_channels=label_channels,
  112. )
  113. (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
  114. pos_gt_index) = cls_reg_targets
  115. cls_scores = levels_to_images(cls_scores)
  116. cls_scores = [
  117. item.reshape(-1, self.cls_out_channels) for item in cls_scores
  118. ]
  119. bbox_preds = levels_to_images(bbox_preds)
  120. bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
  121. iou_preds = levels_to_images(iou_preds)
  122. iou_preds = [item.reshape(-1, 1) for item in iou_preds]
  123. pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
  124. cls_scores, bbox_preds, labels,
  125. labels_weight, bboxes_target,
  126. bboxes_weight, pos_inds)
  127. with torch.no_grad():
  128. reassign_labels, reassign_label_weight, \
  129. reassign_bbox_weights, num_pos = multi_apply(
  130. self.paa_reassign,
  131. pos_losses_list,
  132. labels,
  133. labels_weight,
  134. bboxes_weight,
  135. pos_inds,
  136. pos_gt_index,
  137. anchor_list)
  138. num_pos = sum(num_pos)
  139. # convert all tensor list to a flatten tensor
  140. cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
  141. bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
  142. iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
  143. labels = torch.cat(reassign_labels, 0).view(-1)
  144. flatten_anchors = torch.cat(
  145. [torch.cat(item, 0) for item in anchor_list])
  146. labels_weight = torch.cat(reassign_label_weight, 0).view(-1)
  147. bboxes_target = torch.cat(bboxes_target,
  148. 0).view(-1, bboxes_target[0].size(-1))
  149. pos_inds_flatten = ((labels >= 0)
  150. &
  151. (labels < self.num_classes)).nonzero().reshape(-1)
  152. losses_cls = self.loss_cls(
  153. cls_scores,
  154. labels,
  155. labels_weight,
  156. avg_factor=max(num_pos, len(img_metas))) # avoid num_pos=0
  157. if num_pos:
  158. pos_bbox_pred = self.bbox_coder.decode(
  159. flatten_anchors[pos_inds_flatten],
  160. bbox_preds[pos_inds_flatten])
  161. pos_bbox_target = bboxes_target[pos_inds_flatten]
  162. iou_target = bbox_overlaps(
  163. pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
  164. losses_iou = self.loss_centerness(
  165. iou_preds[pos_inds_flatten],
  166. iou_target.unsqueeze(-1),
  167. avg_factor=num_pos)
  168. losses_bbox = self.loss_bbox(
  169. pos_bbox_pred,
  170. pos_bbox_target,
  171. iou_target.clamp(min=EPS),
  172. avg_factor=iou_target.sum())
  173. else:
  174. losses_iou = iou_preds.sum() * 0
  175. losses_bbox = bbox_preds.sum() * 0
  176. return dict(
  177. loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)
  178. def get_pos_loss(self, anchors, cls_score, bbox_pred, label, label_weight,
  179. bbox_target, bbox_weight, pos_inds):
  180. """Calculate loss of all potential positive samples obtained from first
  181. match process.
  182. Args:
  183. anchors (list[Tensor]): Anchors of each scale.
  184. cls_score (Tensor): Box scores of single image with shape
  185. (num_anchors, num_classes)
  186. bbox_pred (Tensor): Box energies / deltas of single image
  187. with shape (num_anchors, 4)
  188. label (Tensor): classification target of each anchor with
  189. shape (num_anchors,)
  190. label_weight (Tensor): Classification loss weight of each
  191. anchor with shape (num_anchors).
  192. bbox_target (dict): Regression target of each anchor with
  193. shape (num_anchors, 4).
  194. bbox_weight (Tensor): Bbox weight of each anchor with shape
  195. (num_anchors, 4).
  196. pos_inds (Tensor): Index of all positive samples got from
  197. first assign process.
  198. Returns:
  199. Tensor: Losses of all positive samples in single image.
  200. """
  201. if not len(pos_inds):
  202. return cls_score.new([]),
  203. anchors_all_level = torch.cat(anchors, 0)
  204. pos_scores = cls_score[pos_inds]
  205. pos_bbox_pred = bbox_pred[pos_inds]
  206. pos_label = label[pos_inds]
  207. pos_label_weight = label_weight[pos_inds]
  208. pos_bbox_target = bbox_target[pos_inds]
  209. pos_bbox_weight = bbox_weight[pos_inds]
  210. pos_anchors = anchors_all_level[pos_inds]
  211. pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)
  212. # to keep loss dimension
  213. loss_cls = self.loss_cls(
  214. pos_scores,
  215. pos_label,
  216. pos_label_weight,
  217. avg_factor=self.loss_cls.loss_weight,
  218. reduction_override='none')
  219. loss_bbox = self.loss_bbox(
  220. pos_bbox_pred,
  221. pos_bbox_target,
  222. pos_bbox_weight,
  223. avg_factor=self.loss_cls.loss_weight,
  224. reduction_override='none')
  225. loss_cls = loss_cls.sum(-1)
  226. pos_loss = loss_bbox + loss_cls
  227. return pos_loss,
  228. def paa_reassign(self, pos_losses, label, label_weight, bbox_weight,
  229. pos_inds, pos_gt_inds, anchors):
  230. """Fit loss to GMM distribution and separate positive, ignore, negative
  231. samples again with GMM model.
  232. Args:
  233. pos_losses (Tensor): Losses of all positive samples in
  234. single image.
  235. label (Tensor): classification target of each anchor with
  236. shape (num_anchors,)
  237. label_weight (Tensor): Classification loss weight of each
  238. anchor with shape (num_anchors).
  239. bbox_weight (Tensor): Bbox weight of each anchor with shape
  240. (num_anchors, 4).
  241. pos_inds (Tensor): Index of all positive samples got from
  242. first assign process.
  243. pos_gt_inds (Tensor): Gt_index of all positive samples got
  244. from first assign process.
  245. anchors (list[Tensor]): Anchors of each scale.
  246. Returns:
  247. tuple: Usually returns a tuple containing learning targets.
  248. - label (Tensor): classification target of each anchor after
  249. paa assign, with shape (num_anchors,)
  250. - label_weight (Tensor): Classification loss weight of each
  251. anchor after paa assign, with shape (num_anchors).
  252. - bbox_weight (Tensor): Bbox weight of each anchor with shape
  253. (num_anchors, 4).
  254. - num_pos (int): The number of positive samples after paa
  255. assign.
  256. """
  257. if not len(pos_inds):
  258. return label, label_weight, bbox_weight, 0
  259. label = label.clone()
  260. label_weight = label_weight.clone()
  261. bbox_weight = bbox_weight.clone()
  262. num_gt = pos_gt_inds.max() + 1
  263. num_level = len(anchors)
  264. num_anchors_each_level = [item.size(0) for item in anchors]
  265. num_anchors_each_level.insert(0, 0)
  266. inds_level_interval = np.cumsum(num_anchors_each_level)
  267. pos_level_mask = []
  268. for i in range(num_level):
  269. mask = (pos_inds >= inds_level_interval[i]) & (
  270. pos_inds < inds_level_interval[i + 1])
  271. pos_level_mask.append(mask)
  272. pos_inds_after_paa = [label.new_tensor([])]
  273. ignore_inds_after_paa = [label.new_tensor([])]
  274. for gt_ind in range(num_gt):
  275. pos_inds_gmm = []
  276. pos_loss_gmm = []
  277. gt_mask = pos_gt_inds == gt_ind
  278. for level in range(num_level):
  279. level_mask = pos_level_mask[level]
  280. level_gt_mask = level_mask & gt_mask
  281. value, topk_inds = pos_losses[level_gt_mask].topk(
  282. min(level_gt_mask.sum(), self.topk), largest=False)
  283. pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds])
  284. pos_loss_gmm.append(value)
  285. pos_inds_gmm = torch.cat(pos_inds_gmm)
  286. pos_loss_gmm = torch.cat(pos_loss_gmm)
  287. # fix gmm need at least two sample
  288. if len(pos_inds_gmm) < 2:
  289. continue
  290. device = pos_inds_gmm.device
  291. pos_loss_gmm, sort_inds = pos_loss_gmm.sort()
  292. pos_inds_gmm = pos_inds_gmm[sort_inds]
  293. pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy()
  294. min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max()
  295. means_init = np.array([min_loss, max_loss]).reshape(2, 1)
  296. weights_init = np.array([0.5, 0.5])
  297. precisions_init = np.array([1.0, 1.0]).reshape(2, 1, 1) # full
  298. if self.covariance_type == 'spherical':
  299. precisions_init = precisions_init.reshape(2)
  300. elif self.covariance_type == 'diag':
  301. precisions_init = precisions_init.reshape(2, 1)
  302. elif self.covariance_type == 'tied':
  303. precisions_init = np.array([[1.0]])
  304. if skm is None:
  305. raise ImportError('Please run "pip install sklearn" '
  306. 'to install sklearn first.')
  307. gmm = skm.GaussianMixture(
  308. 2,
  309. weights_init=weights_init,
  310. means_init=means_init,
  311. precisions_init=precisions_init,
  312. covariance_type=self.covariance_type)
  313. gmm.fit(pos_loss_gmm)
  314. gmm_assignment = gmm.predict(pos_loss_gmm)
  315. scores = gmm.score_samples(pos_loss_gmm)
  316. gmm_assignment = torch.from_numpy(gmm_assignment).to(device)
  317. scores = torch.from_numpy(scores).to(device)
  318. pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme(
  319. gmm_assignment, scores, pos_inds_gmm)
  320. pos_inds_after_paa.append(pos_inds_temp)
  321. ignore_inds_after_paa.append(ignore_inds_temp)
  322. pos_inds_after_paa = torch.cat(pos_inds_after_paa)
  323. ignore_inds_after_paa = torch.cat(ignore_inds_after_paa)
  324. reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1)
  325. reassign_ids = pos_inds[reassign_mask]
  326. label[reassign_ids] = self.num_classes
  327. label_weight[ignore_inds_after_paa] = 0
  328. bbox_weight[reassign_ids] = 0
  329. num_pos = len(pos_inds_after_paa)
  330. return label, label_weight, bbox_weight, num_pos
  331. def gmm_separation_scheme(self, gmm_assignment, scores, pos_inds_gmm):
  332. """A general separation scheme for gmm model.
  333. It separates a GMM distribution of candidate samples into three
  334. parts, 0 1 and uncertain areas, and you can implement other
  335. separation schemes by rewriting this function.
  336. Args:
  337. gmm_assignment (Tensor): The prediction of GMM which is of shape
  338. (num_samples,). The 0/1 value indicates the distribution
  339. that each sample comes from.
  340. scores (Tensor): The probability of sample coming from the
  341. fit GMM distribution. The tensor is of shape (num_samples,).
  342. pos_inds_gmm (Tensor): All the indexes of samples which are used
  343. to fit GMM model. The tensor is of shape (num_samples,)
  344. Returns:
  345. tuple[Tensor]: The indices of positive and ignored samples.
  346. - pos_inds_temp (Tensor): Indices of positive samples.
  347. - ignore_inds_temp (Tensor): Indices of ignore samples.
  348. """
  349. # The implementation is (c) in Fig.3 in origin paper instead of (b).
  350. # You can refer to issues such as
  351. # https://github.com/kkhoot/PAA/issues/8 and
  352. # https://github.com/kkhoot/PAA/issues/9.
  353. fgs = gmm_assignment == 0
  354. pos_inds_temp = fgs.new_tensor([], dtype=torch.long)
  355. ignore_inds_temp = fgs.new_tensor([], dtype=torch.long)
  356. if fgs.nonzero().numel():
  357. _, pos_thr_ind = scores[fgs].topk(1)
  358. pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]
  359. ignore_inds_temp = pos_inds_gmm.new_tensor([])
  360. return pos_inds_temp, ignore_inds_temp
  361. def get_targets(
  362. self,
  363. anchor_list,
  364. valid_flag_list,
  365. gt_bboxes_list,
  366. img_metas,
  367. gt_bboxes_ignore_list=None,
  368. gt_labels_list=None,
  369. label_channels=1,
  370. unmap_outputs=True,
  371. ):
  372. """Get targets for PAA head.
  373. This method is almost the same as `AnchorHead.get_targets()`. We direct
  374. return the results from _get_targets_single instead map it to levels
  375. by images_to_levels function.
  376. Args:
  377. anchor_list (list[list[Tensor]]): Multi level anchors of each
  378. image. The outer list indicates images, and the inner list
  379. corresponds to feature levels of the image. Each element of
  380. the inner list is a tensor of shape (num_anchors, 4).
  381. valid_flag_list (list[list[Tensor]]): Multi level valid flags of
  382. each image. The outer list indicates images, and the inner list
  383. corresponds to feature levels of the image. Each element of
  384. the inner list is a tensor of shape (num_anchors, )
  385. gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
  386. img_metas (list[dict]): Meta info of each image.
  387. gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
  388. ignored.
  389. gt_labels_list (list[Tensor]): Ground truth labels of each box.
  390. label_channels (int): Channel of label.
  391. unmap_outputs (bool): Whether to map outputs back to the original
  392. set of anchors.
  393. Returns:
  394. tuple: Usually returns a tuple containing learning targets.
  395. - labels (list[Tensor]): Labels of all anchors, each with
  396. shape (num_anchors,).
  397. - label_weights (list[Tensor]): Label weights of all anchor.
  398. each with shape (num_anchors,).
  399. - bbox_targets (list[Tensor]): BBox targets of all anchors.
  400. each with shape (num_anchors, 4).
  401. - bbox_weights (list[Tensor]): BBox weights of all anchors.
  402. each with shape (num_anchors, 4).
  403. - pos_inds (list[Tensor]): Contains all index of positive
  404. sample in all anchor.
  405. - gt_inds (list[Tensor]): Contains all gt_index of positive
  406. sample in all anchor.
  407. """
  408. num_imgs = len(img_metas)
  409. assert len(anchor_list) == len(valid_flag_list) == num_imgs
  410. concat_anchor_list = []
  411. concat_valid_flag_list = []
  412. for i in range(num_imgs):
  413. assert len(anchor_list[i]) == len(valid_flag_list[i])
  414. concat_anchor_list.append(torch.cat(anchor_list[i]))
  415. concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
  416. # compute targets for each image
  417. if gt_bboxes_ignore_list is None:
  418. gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
  419. if gt_labels_list is None:
  420. gt_labels_list = [None for _ in range(num_imgs)]
  421. results = multi_apply(
  422. self._get_targets_single,
  423. concat_anchor_list,
  424. concat_valid_flag_list,
  425. gt_bboxes_list,
  426. gt_bboxes_ignore_list,
  427. gt_labels_list,
  428. img_metas,
  429. label_channels=label_channels,
  430. unmap_outputs=unmap_outputs)
  431. (labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds,
  432. valid_neg_inds, sampling_result) = results
  433. # Due to valid flag of anchors, we have to calculate the real pos_inds
  434. # in origin anchor set.
  435. pos_inds = []
  436. for i, single_labels in enumerate(labels):
  437. pos_mask = (0 <= single_labels) & (
  438. single_labels < self.num_classes)
  439. pos_inds.append(pos_mask.nonzero().view(-1))
  440. gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]
  441. return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
  442. gt_inds)
  443. def _get_targets_single(self,
  444. flat_anchors,
  445. valid_flags,
  446. gt_bboxes,
  447. gt_bboxes_ignore,
  448. gt_labels,
  449. img_meta,
  450. label_channels=1,
  451. unmap_outputs=True):
  452. """Compute regression and classification targets for anchors in a
  453. single image.
  454. This method is same as `AnchorHead._get_targets_single()`.
  455. """
  456. assert unmap_outputs, 'We must map outputs back to the original' \
  457. 'set of anchors in PAAhead'
  458. return super(ATSSHead, self)._get_targets_single(
  459. flat_anchors,
  460. valid_flags,
  461. gt_bboxes,
  462. gt_bboxes_ignore,
  463. gt_labels,
  464. img_meta,
  465. label_channels=1,
  466. unmap_outputs=True)
  467. @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
  468. def get_bboxes(self,
  469. cls_scores,
  470. bbox_preds,
  471. score_factors=None,
  472. img_metas=None,
  473. cfg=None,
  474. rescale=False,
  475. with_nms=True,
  476. **kwargs):
  477. assert with_nms, 'PAA only supports "with_nms=True" now and it ' \
  478. 'means PAAHead does not support ' \
  479. 'test-time augmentation'
  480. return super(ATSSHead, self).get_bboxes(cls_scores, bbox_preds,
  481. score_factors, img_metas, cfg,
  482. rescale, with_nms, **kwargs)
  483. def _get_bboxes_single(self,
  484. cls_score_list,
  485. bbox_pred_list,
  486. score_factor_list,
  487. mlvl_priors,
  488. img_meta,
  489. cfg,
  490. rescale=False,
  491. with_nms=True,
  492. **kwargs):
  493. """Transform outputs of a single image into bbox predictions.
  494. Args:
  495. cls_score_list (list[Tensor]): Box scores from all scale
  496. levels of a single image, each item has shape
  497. (num_priors * num_classes, H, W).
  498. bbox_pred_list (list[Tensor]): Box energies / deltas from
  499. all scale levels of a single image, each item has shape
  500. (num_priors * 4, H, W).
  501. score_factor_list (list[Tensor]): Score factors from all scale
  502. levels of a single image, each item has shape
  503. (num_priors * 1, H, W).
  504. mlvl_priors (list[Tensor]): Each element in the list is
  505. the priors of a single level in feature pyramid, has shape
  506. (num_priors, 4).
  507. img_meta (dict): Image meta info.
  508. cfg (mmcv.Config): Test / postprocessing configuration,
  509. if None, test_cfg would be used.
  510. rescale (bool): If True, return boxes in original image space.
  511. Default: False.
  512. with_nms (bool): If True, do nms before return boxes.
  513. Default: True.
  514. Returns:
  515. tuple[Tensor]: Results of detected bboxes and labels. If with_nms
  516. is False and mlvl_score_factor is None, return mlvl_bboxes and
  517. mlvl_scores, else return mlvl_bboxes, mlvl_scores and
  518. mlvl_score_factor. Usually with_nms is False is used for aug
  519. test. If with_nms is True, then return the following format
  520. - det_bboxes (Tensor): Predicted bboxes with shape \
  521. [num_bboxes, 5], where the first 4 columns are bounding \
  522. box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
  523. column are scores between 0 and 1.
  524. - det_labels (Tensor): Predicted labels of the corresponding \
  525. box with shape [num_bboxes].
  526. """
  527. cfg = self.test_cfg if cfg is None else cfg
  528. img_shape = img_meta['img_shape']
  529. nms_pre = cfg.get('nms_pre', -1)
  530. mlvl_bboxes = []
  531. mlvl_scores = []
  532. mlvl_score_factors = []
  533. for level_idx, (cls_score, bbox_pred, score_factor, priors) in \
  534. enumerate(zip(cls_score_list, bbox_pred_list,
  535. score_factor_list, mlvl_priors)):
  536. assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
  537. scores = cls_score.permute(1, 2, 0).reshape(
  538. -1, self.cls_out_channels).sigmoid()
  539. bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
  540. score_factor = score_factor.permute(1, 2, 0).reshape(-1).sigmoid()
  541. if 0 < nms_pre < scores.shape[0]:
  542. max_scores, _ = (scores *
  543. score_factor[:, None]).sqrt().max(dim=1)
  544. _, topk_inds = max_scores.topk(nms_pre)
  545. priors = priors[topk_inds, :]
  546. bbox_pred = bbox_pred[topk_inds, :]
  547. scores = scores[topk_inds, :]
  548. score_factor = score_factor[topk_inds]
  549. bboxes = self.bbox_coder.decode(
  550. priors, bbox_pred, max_shape=img_shape)
  551. mlvl_bboxes.append(bboxes)
  552. mlvl_scores.append(scores)
  553. mlvl_score_factors.append(score_factor)
  554. return self._bbox_post_process(mlvl_scores, mlvl_bboxes,
  555. img_meta['scale_factor'], cfg, rescale,
  556. with_nms, mlvl_score_factors, **kwargs)
  557. def _bbox_post_process(self,
  558. mlvl_scores,
  559. mlvl_bboxes,
  560. scale_factor,
  561. cfg,
  562. rescale=False,
  563. with_nms=True,
  564. mlvl_score_factors=None,
  565. **kwargs):
  566. """bbox post-processing method.
  567. The boxes would be rescaled to the original image scale and do
  568. the nms operation. Usually with_nms is False is used for aug test.
  569. Args:
  570. mlvl_scores (list[Tensor]): Box scores from all scale
  571. levels of a single image, each item has shape
  572. (num_bboxes, num_class).
  573. mlvl_bboxes (list[Tensor]): Decoded bboxes from all scale
  574. levels of a single image, each item has shape (num_bboxes, 4).
  575. scale_factor (ndarray, optional): Scale factor of the image arange
  576. as (w_scale, h_scale, w_scale, h_scale).
  577. cfg (mmcv.Config): Test / postprocessing configuration,
  578. if None, test_cfg would be used.
  579. rescale (bool): If True, return boxes in original image space.
  580. Default: False.
  581. with_nms (bool): If True, do nms before return boxes.
  582. Default: True.
  583. mlvl_score_factors (list[Tensor], optional): Score factor from
  584. all scale levels of a single image, each item has shape
  585. (num_bboxes, ). Default: None.
  586. Returns:
  587. tuple[Tensor]: Results of detected bboxes and labels. If with_nms
  588. is False and mlvl_score_factor is None, return mlvl_bboxes and
  589. mlvl_scores, else return mlvl_bboxes, mlvl_scores and
  590. mlvl_score_factor. Usually with_nms is False is used for aug
  591. test. If with_nms is True, then return the following format
  592. - det_bboxes (Tensor): Predicted bboxes with shape \
  593. [num_bboxes, 5], where the first 4 columns are bounding \
  594. box positions (tl_x, tl_y, br_x, br_y) and the 5-th \
  595. column are scores between 0 and 1.
  596. - det_labels (Tensor): Predicted labels of the corresponding \
  597. box with shape [num_bboxes].
  598. """
  599. mlvl_bboxes = torch.cat(mlvl_bboxes)
  600. if rescale:
  601. mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
  602. mlvl_scores = torch.cat(mlvl_scores)
  603. # Add a dummy background class to the backend when using sigmoid
  604. # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
  605. # BG cat_id: num_class
  606. padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
  607. mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
  608. mlvl_iou_preds = torch.cat(mlvl_score_factors)
  609. mlvl_nms_scores = (mlvl_scores * mlvl_iou_preds[:, None]).sqrt()
  610. det_bboxes, det_labels = multiclass_nms(
  611. mlvl_bboxes,
  612. mlvl_nms_scores,
  613. cfg.score_thr,
  614. cfg.nms,
  615. cfg.max_per_img,
  616. score_factors=None)
  617. if self.with_score_voting and len(det_bboxes) > 0:
  618. det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels,
  619. mlvl_bboxes,
  620. mlvl_nms_scores,
  621. cfg.score_thr)
  622. return det_bboxes, det_labels
  623. def score_voting(self, det_bboxes, det_labels, mlvl_bboxes,
  624. mlvl_nms_scores, score_thr):
  625. """Implementation of score voting method works on each remaining boxes
  626. after NMS procedure.
  627. Args:
  628. det_bboxes (Tensor): Remaining boxes after NMS procedure,
  629. with shape (k, 5), each dimension means
  630. (x1, y1, x2, y2, score).
  631. det_labels (Tensor): The label of remaining boxes, with shape
  632. (k, 1),Labels are 0-based.
  633. mlvl_bboxes (Tensor): All boxes before the NMS procedure,
  634. with shape (num_anchors,4).
  635. mlvl_nms_scores (Tensor): The scores of all boxes which is used
  636. in the NMS procedure, with shape (num_anchors, num_class)
  637. score_thr (float): The score threshold of bboxes.
  638. Returns:
  639. tuple: Usually returns a tuple containing voting results.
  640. - det_bboxes_voted (Tensor): Remaining boxes after
  641. score voting procedure, with shape (k, 5), each
  642. dimension means (x1, y1, x2, y2, score).
  643. - det_labels_voted (Tensor): Label of remaining bboxes
  644. after voting, with shape (num_anchors,).
  645. """
  646. candidate_mask = mlvl_nms_scores > score_thr
  647. candidate_mask_nonzeros = candidate_mask.nonzero(as_tuple=False)
  648. candidate_inds = candidate_mask_nonzeros[:, 0]
  649. candidate_labels = candidate_mask_nonzeros[:, 1]
  650. candidate_bboxes = mlvl_bboxes[candidate_inds]
  651. candidate_scores = mlvl_nms_scores[candidate_mask]
  652. det_bboxes_voted = []
  653. det_labels_voted = []
  654. for cls in range(self.cls_out_channels):
  655. candidate_cls_mask = candidate_labels == cls
  656. if not candidate_cls_mask.any():
  657. continue
  658. candidate_cls_scores = candidate_scores[candidate_cls_mask]
  659. candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask]
  660. det_cls_mask = det_labels == cls
  661. det_cls_bboxes = det_bboxes[det_cls_mask].view(
  662. -1, det_bboxes.size(-1))
  663. det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4],
  664. candidate_cls_bboxes)
  665. for det_ind in range(len(det_cls_bboxes)):
  666. single_det_ious = det_candidate_ious[det_ind]
  667. pos_ious_mask = single_det_ious > 0.01
  668. pos_ious = single_det_ious[pos_ious_mask]
  669. pos_bboxes = candidate_cls_bboxes[pos_ious_mask]
  670. pos_scores = candidate_cls_scores[pos_ious_mask]
  671. pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) *
  672. pos_scores)[:, None]
  673. voted_box = torch.sum(
  674. pis * pos_bboxes, dim=0) / torch.sum(
  675. pis, dim=0)
  676. voted_score = det_cls_bboxes[det_ind][-1:][None, :]
  677. det_bboxes_voted.append(
  678. torch.cat((voted_box[None, :], voted_score), dim=1))
  679. det_labels_voted.append(cls)
  680. det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0)
  681. det_labels_voted = det_labels.new_tensor(det_labels_voted)
  682. return det_bboxes_voted, det_labels_voted

No Description

Contributors (3)