You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trident_roi_head.py 5.3 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from mmcv.ops import batched_nms
  4. from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
  5. multiclass_nms)
  6. from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
  7. from ..builder import HEADS
  8. @HEADS.register_module()
  9. class TridentRoIHead(StandardRoIHead):
  10. """Trident roi head.
  11. Args:
  12. num_branch (int): Number of branches in TridentNet.
  13. test_branch_idx (int): In inference, all 3 branches will be used
  14. if `test_branch_idx==-1`, otherwise only branch with index
  15. `test_branch_idx` will be used.
  16. """
  17. def __init__(self, num_branch, test_branch_idx, **kwargs):
  18. self.num_branch = num_branch
  19. self.test_branch_idx = test_branch_idx
  20. super(TridentRoIHead, self).__init__(**kwargs)
  21. def merge_trident_bboxes(self, trident_det_bboxes, trident_det_labels):
  22. """Merge bbox predictions of each branch."""
  23. if trident_det_bboxes.numel() == 0:
  24. det_bboxes = trident_det_bboxes.new_zeros((0, 5))
  25. det_labels = trident_det_bboxes.new_zeros((0, ), dtype=torch.long)
  26. else:
  27. nms_bboxes = trident_det_bboxes[:, :4]
  28. nms_scores = trident_det_bboxes[:, 4].contiguous()
  29. nms_inds = trident_det_labels
  30. nms_cfg = self.test_cfg['nms']
  31. det_bboxes, keep = batched_nms(nms_bboxes, nms_scores, nms_inds,
  32. nms_cfg)
  33. det_labels = trident_det_labels[keep]
  34. if self.test_cfg['max_per_img'] > 0:
  35. det_labels = det_labels[:self.test_cfg['max_per_img']]
  36. det_bboxes = det_bboxes[:self.test_cfg['max_per_img']]
  37. return det_bboxes, det_labels
  38. def simple_test(self,
  39. x,
  40. proposal_list,
  41. img_metas,
  42. proposals=None,
  43. rescale=False):
  44. """Test without augmentation as follows:
  45. 1. Compute prediction bbox and label per branch.
  46. 2. Merge predictions of each branch according to scores of
  47. bboxes, i.e., bboxes with higher score are kept to give
  48. top-k prediction.
  49. """
  50. assert self.with_bbox, 'Bbox head must be implemented.'
  51. det_bboxes_list, det_labels_list = self.simple_test_bboxes(
  52. x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
  53. num_branch = self.num_branch if self.test_branch_idx == -1 else 1
  54. for _ in range(len(det_bboxes_list)):
  55. if det_bboxes_list[_].shape[0] == 0:
  56. det_bboxes_list[_] = det_bboxes_list[_].new_empty((0, 5))
  57. det_bboxes, det_labels = [], []
  58. for i in range(len(img_metas) // num_branch):
  59. det_result = self.merge_trident_bboxes(
  60. torch.cat(det_bboxes_list[i * num_branch:(i + 1) *
  61. num_branch]),
  62. torch.cat(det_labels_list[i * num_branch:(i + 1) *
  63. num_branch]))
  64. det_bboxes.append(det_result[0])
  65. det_labels.append(det_result[1])
  66. bbox_results = [
  67. bbox2result(det_bboxes[i], det_labels[i],
  68. self.bbox_head.num_classes)
  69. for i in range(len(det_bboxes))
  70. ]
  71. return bbox_results
  72. def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
  73. """Test det bboxes with test time augmentation."""
  74. aug_bboxes = []
  75. aug_scores = []
  76. for x, img_meta in zip(feats, img_metas):
  77. # only one image in the batch
  78. img_shape = img_meta[0]['img_shape']
  79. scale_factor = img_meta[0]['scale_factor']
  80. flip = img_meta[0]['flip']
  81. flip_direction = img_meta[0]['flip_direction']
  82. trident_bboxes, trident_scores = [], []
  83. for branch_idx in range(len(proposal_list)):
  84. proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
  85. scale_factor, flip, flip_direction)
  86. rois = bbox2roi([proposals])
  87. bbox_results = self._bbox_forward(x, rois)
  88. bboxes, scores = self.bbox_head.get_bboxes(
  89. rois,
  90. bbox_results['cls_score'],
  91. bbox_results['bbox_pred'],
  92. img_shape,
  93. scale_factor,
  94. rescale=False,
  95. cfg=None)
  96. trident_bboxes.append(bboxes)
  97. trident_scores.append(scores)
  98. aug_bboxes.append(torch.cat(trident_bboxes, 0))
  99. aug_scores.append(torch.cat(trident_scores, 0))
  100. # after merging, bboxes will be rescaled to the original image size
  101. merged_bboxes, merged_scores = merge_aug_bboxes(
  102. aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
  103. det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
  104. rcnn_test_cfg.score_thr,
  105. rcnn_test_cfg.nms,
  106. rcnn_test_cfg.max_per_img)
  107. return det_bboxes, det_labels

No Description

Contributors (2)