
sparse_rcnn.py
# Copyright (c) OpenMMLab. All rights reserved.
from ..builder import DETECTORS
from .two_stage import TwoStageDetector


@DETECTORS.register_module()
class SparseRCNN(TwoStageDetector):
    r"""Implementation of `Sparse R-CNN: End-to-End Object Detection with
    Learnable Proposals <https://arxiv.org/abs/2011.12450>`_"""

    def __init__(self, *args, **kwargs):
        super(SparseRCNN, self).__init__(*args, **kwargs)
        assert self.with_rpn, 'Sparse R-CNN and QueryInst ' \
            'do not support external proposals'

    def forward_train(self,
                      img,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      proposals=None,
                      **kwargs):
        """Forward function of Sparse R-CNN and QueryInst in train stage.

        Args:
            img (Tensor): of shape (N, C, H, W) encoding input images.
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (list[Tensor], optional): Segmentation masks for
                each box. This is required to train QueryInst.
            proposals (list[Tensor], optional): override rpn proposals with
                custom proposals. Use when `with_rpn` is False.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        assert proposals is None, 'Sparse R-CNN and QueryInst ' \
            'do not support external proposals'

        x = self.extract_feat(img)
        proposal_boxes, proposal_features, imgs_whwh = \
            self.rpn_head.forward_train(x, img_metas)
        roi_losses = self.roi_head.forward_train(
            x,
            proposal_boxes,
            proposal_features,
            img_metas,
            gt_bboxes,
            gt_labels,
            gt_bboxes_ignore=gt_bboxes_ignore,
            gt_masks=gt_masks,
            imgs_whwh=imgs_whwh)
        return roi_losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test time augmentation.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): List of image information.
            rescale (bool): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[list[np.ndarray]]: BBox results of each image and classes.
                The outer list corresponds to each image. The inner list
                corresponds to each class.
        """
        x = self.extract_feat(img)
        proposal_boxes, proposal_features, imgs_whwh = \
            self.rpn_head.simple_test_rpn(x, img_metas)
        results = self.roi_head.simple_test(
            x,
            proposal_boxes,
            proposal_features,
            img_metas,
            imgs_whwh=imgs_whwh,
            rescale=rescale)
        return results

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/analysis_tools/get_flops.py`
        """
        # backbone
        x = self.extract_feat(img)
        # rpn
        num_imgs = len(img)
        dummy_img_metas = [
            dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)
        ]
        proposal_boxes, proposal_features, imgs_whwh = \
            self.rpn_head.simple_test_rpn(x, dummy_img_metas)
        # roi_head
        roi_outs = self.roi_head.forward_dummy(x, proposal_boxes,
                                               proposal_features,
                                               dummy_img_metas)
        return roi_outs
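
Because SparseRCNN is registered in the DETECTORS registry, it is normally built from a config file rather than constructed directly. The snippet below is a minimal sketch, not part of this file, assuming an MMDetection 2.x install run from the repository root; the config path `configs/sparse_rcnn/sparse_rcnn_r50_fpn_1x_coco.py` and the input size are illustrative assumptions. It builds the detector and exercises `forward_dummy`, which is the entry point used by `tools/analysis_tools/get_flops.py`.

# Usage sketch (assumes MMDetection 2.x; config path is illustrative).
import torch
from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('configs/sparse_rcnn/sparse_rcnn_r50_fpn_1x_coco.py')
# Instantiates SparseRCNN via the DETECTORS registry from the model config.
model = build_detector(cfg.model)
model.eval()

# forward_dummy only needs an image tensor; it fabricates img_metas itself.
with torch.no_grad():
    roi_outs = model.forward_dummy(torch.randn(1, 3, 800, 1333))

For real training and inference, the high-level helpers in `mmdet.apis` (e.g. `train_detector`, `init_detector`, `inference_detector`) are the usual way to reach `forward_train` and `simple_test`; they also handle checkpoint loading and the data pipeline that produces `img_metas`.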
