You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

dense_test_mixins.py 8.4 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import sys
  3. from inspect import signature
  4. import torch
  5. from mmcv.ops import batched_nms
  6. from mmdet.core import bbox_mapping_back, merge_aug_proposals
  7. if sys.version_info >= (3, 7):
  8. from mmdet.utils.contextmanagers import completed
  class BBoxTestMixin(object):
      """Mixin class for testing det bboxes via DenseHead.

      The methods below call ``self.forward`` / ``self(...)``,
      ``self.get_bboxes``, ``self._get_bboxes_single`` and read
      ``self.test_cfg``; none of these is defined in this mixin, so the
      class it is mixed into has to provide them.
      """

      def simple_test_bboxes(self, feats, img_metas, rescale=False):
          """Test det bboxes without test-time augmentation, can be applied in
          DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``,
          etc.

          Args:
              feats (tuple[torch.Tensor]): Multi-level features from the
                  upstream network, each is a 4D-tensor.
              img_metas (list[dict]): List of image information.
              rescale (bool, optional): Whether to rescale the results.
                  Defaults to False.

          Returns:
              list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
                  The first item is ``bboxes`` with shape (n, 5),
                  where 5 represent (tl_x, tl_y, br_x, br_y, score).
                  The shape of the second tensor in the tuple is ``labels``
                  with shape (n,)
          """
          outs = self.forward(feats)
          results_list = self.get_bboxes(
              *outs, img_metas=img_metas, rescale=rescale)
          return results_list

      def aug_test_bboxes(self, feats, img_metas, rescale=False):
          """Test det bboxes with test time augmentation, can be applied in
          DenseHead except for ``RPNHead`` and its variants, e.g., ``GARPNHead``,
          etc.

          Args:
              feats (list[Tensor]): the outer list indicates test-time
                  augmentations and inner Tensor should have a shape NxCxHxW,
                  which contains features for all images in the batch.
              img_metas (list[list[dict]]): the outer list indicates test-time
                  augs (multiscale, flip, etc.) and the inner list indicates
                  images in a batch. each dict has image information.
              rescale (bool, optional): Whether to rescale the results.
                  Defaults to False.

          Returns:
              list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
                  The first item is ``bboxes`` with shape (n, 5),
                  where 5 represent (tl_x, tl_y, br_x, br_y, score).
                  The shape of the second tensor in the tuple is ``labels``
                  with shape (n,). The length of list should always be 1.
                  The same list-of-one-tuple structure is returned when no
                  box survives merging (the tensors are then empty).
          """
          # Both ``get_bboxes`` and ``_get_bboxes_single`` must accept
          # ``with_nms`` so that NMS can be deferred until after merging.
          supports_tta = all(
              'with_nms' in signature(func).parameters
              for func in (self.get_bboxes, self._get_bboxes_single))
          assert supports_tta, \
              f'{self.__class__.__name__}' \
              ' does not support test-time augmentation'

          aug_bboxes = []
          aug_scores = []
          aug_labels = []
          for x, img_meta in zip(feats, img_metas):
              # only one image in the batch
              outs = self.forward(x)
              bbox_outputs = self.get_bboxes(
                  *outs,
                  img_metas=img_meta,
                  cfg=self.test_cfg,
                  rescale=False,
                  with_nms=False)[0]
              aug_bboxes.append(bbox_outputs[0])
              aug_scores.append(bbox_outputs[1])
              # labels are optional in the per-augmentation output
              if len(bbox_outputs) >= 3:
                  aug_labels.append(bbox_outputs[2])

          # after merging, bboxes will be rescaled to the original image size
          merged_bboxes, merged_scores = self.merge_aug_bboxes(
              aug_bboxes, aug_scores, img_metas)
          merged_labels = torch.cat(aug_labels, dim=0) if aug_labels else None

          if merged_bboxes.numel() == 0:
              det_bboxes = torch.cat([merged_bboxes, merged_scores[:, None]], -1)
              # Keep the return structure identical to the non-empty path
              # (a list holding one (bboxes, labels) tuple) so callers can
              # index the result uniformly.
              return [
                  (det_bboxes, merged_labels),
              ]

          # NOTE(review): ``merged_labels`` is None here if no augmentation
          # returned labels; the indexing below would then fail - confirm
          # every supported head returns labels when ``with_nms=False``.
          det_bboxes, keep_idxs = batched_nms(merged_bboxes, merged_scores,
                                              merged_labels, self.test_cfg.nms)
          det_bboxes = det_bboxes[:self.test_cfg.max_per_img]
          det_labels = merged_labels[keep_idxs][:self.test_cfg.max_per_img]

          if rescale:
              # merged boxes are already in the original image scale
              _det_bboxes = det_bboxes
          else:
              # map back to the scale of the first augmentation's image;
              # clone so the NMS output itself is not modified in place
              _det_bboxes = det_bboxes.clone()
              _det_bboxes[:, :4] *= det_bboxes.new_tensor(
                  img_metas[0][0]['scale_factor'])

          return [
              (_det_bboxes, det_labels),
          ]

      def simple_test_rpn(self, x, img_metas):
          """Test without augmentation, only for ``RPNHead`` and its variants,
          e.g., ``GARPNHead``, etc.

          Args:
              x (tuple[Tensor]): Features from the upstream network, each is
                  a 4D-tensor.
              img_metas (list[dict]): Meta info of each image.

          Returns:
              list[Tensor]: Proposals of each image, each item has shape (n, 5),
                  where 5 represent (tl_x, tl_y, br_x, br_y, score).
          """
          rpn_outs = self(x)
          proposal_list = self.get_bboxes(*rpn_outs, img_metas=img_metas)
          return proposal_list

      def aug_test_rpn(self, feats, img_metas):
          """Test with augmentation, only for ``RPNHead`` and its variants,
          e.g., ``GARPNHead``, etc.

          Args:
              feats (list): One entry per test-time augmentation; each entry
                  is passed unchanged to ``simple_test_rpn`` as ``x``.
              img_metas (list[list[dict]]): The outer list indicates
                  test-time augmentations and the inner list indicates
                  images in a batch.

          Returns:
              list[Tensor]: Merged proposals of each image, each item has
                  shape (n, 5), where 5 represent
                  (tl_x, tl_y, br_x, br_y, score).
          """
          samples_per_gpu = len(img_metas[0])

          # aug_proposals[i] collects the proposals of image i over all augs
          aug_proposals = [[] for _ in range(samples_per_gpu)]
          for x, img_meta in zip(feats, img_metas):
              proposal_list = self.simple_test_rpn(x, img_meta)
              for i, proposals in enumerate(proposal_list):
                  aug_proposals[i].append(proposals)

          # reorganize the order of 'img_metas' to match the dimensions
          # of 'aug_proposals' (per-aug lists -> per-image lists)
          aug_img_metas = [[metas_of_aug[i] for metas_of_aug in img_metas]
                           for i in range(samples_per_gpu)]

          # after merging, proposals will be rescaled to the original image size
          merged_proposals = [
              merge_aug_proposals(proposals, aug_img_meta, self.test_cfg)
              for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
          ]
          return merged_proposals

      if sys.version_info >= (3, 7):

          async def async_simple_test_rpn(self, x, img_metas):
              """Asynchronous variant of ``simple_test_rpn``.

              The forward pass runs inside the ``completed`` async context
              manager; proposals are decoded after it exits.

              Args:
                  x (tuple[Tensor]): Features from the upstream network.
                  img_metas (list[dict]): Meta info of each image.

              Returns:
                  list[Tensor]: Proposals of each image, as returned by
                      ``get_bboxes``.
              """
              # Read the interval with ``get`` rather than ``pop``: popping
              # would delete the setting from the shared ``test_cfg`` on the
              # first call and make later calls fall back to the default.
              sleep_interval = self.test_cfg.get('async_sleep_interval', 0.025)
              async with completed(
                      __name__, 'rpn_head_forward',
                      sleep_interval=sleep_interval):
                  rpn_outs = self(x)

              proposal_list = self.get_bboxes(*rpn_outs, img_metas=img_metas)
              return proposal_list

      def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
          """Merge augmented detection bboxes and scores.

          Args:
              aug_bboxes (list[Tensor]): shape (n, 4*#class)
              aug_scores (list[Tensor] or None): shape (n, #class)
              img_metas (list[list[dict]]): Image info for each augmentation.
                  Only the first dict of every inner list is read, using its
                  ``img_shape``, ``scale_factor``, ``flip`` and
                  ``flip_direction`` entries.

          Returns:
              tuple[Tensor]: ``bboxes`` with shape (n,4), where
              4 represent (tl_x, tl_y, br_x, br_y)
              and ``scores`` with shape (n,). If ``aug_scores`` is None,
              only the ``bboxes`` tensor is returned.
          """
          recovered_bboxes = []
          for bboxes, img_info in zip(aug_bboxes, img_metas):
              meta = img_info[0]
              # undo the scaling / flipping applied by this augmentation
              recovered_bboxes.append(
                  bbox_mapping_back(bboxes, meta['img_shape'],
                                    meta['scale_factor'], meta['flip'],
                                    meta['flip_direction']))
          bboxes = torch.cat(recovered_bboxes, dim=0)
          if aug_scores is None:
              return bboxes
          scores = torch.cat(aug_scores, dim=0)
          return bboxes, scores

No Description

Contributors (3)