You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

embedding_rpn_head.py 4.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. from mmcv.runner import BaseModule
  5. from mmdet.models.builder import HEADS
  6. from ...core import bbox_cxcywh_to_xyxy
  7. @HEADS.register_module()
  8. class EmbeddingRPNHead(BaseModule):
  9. """RPNHead in the `Sparse R-CNN <https://arxiv.org/abs/2011.12450>`_ .
  10. Unlike traditional RPNHead, this module does not need FPN input, but just
  11. decode `init_proposal_bboxes` and expand the first dimension of
  12. `init_proposal_bboxes` and `init_proposal_features` to the batch_size.
  13. Args:
  14. num_proposals (int): Number of init_proposals. Default 100.
  15. proposal_feature_channel (int): Channel number of
  16. init_proposal_feature. Defaults to 256.
  17. init_cfg (dict or list[dict], optional): Initialization config dict.
  18. Default: None
  19. """
  20. def __init__(self,
  21. num_proposals=100,
  22. proposal_feature_channel=256,
  23. init_cfg=None,
  24. **kwargs):
  25. assert init_cfg is None, 'To prevent abnormal initialization ' \
  26. 'behavior, init_cfg is not allowed to be set'
  27. super(EmbeddingRPNHead, self).__init__(init_cfg)
  28. self.num_proposals = num_proposals
  29. self.proposal_feature_channel = proposal_feature_channel
  30. self._init_layers()
  31. def _init_layers(self):
  32. """Initialize a sparse set of proposal boxes and proposal features."""
  33. self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
  34. self.init_proposal_features = nn.Embedding(
  35. self.num_proposals, self.proposal_feature_channel)
  36. def init_weights(self):
  37. """Initialize the init_proposal_bboxes as normalized.
  38. [c_x, c_y, w, h], and we initialize it to the size of the entire
  39. image.
  40. """
  41. super(EmbeddingRPNHead, self).init_weights()
  42. nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5)
  43. nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1)
  44. def _decode_init_proposals(self, imgs, img_metas):
  45. """Decode init_proposal_bboxes according to the size of images and
  46. expand dimension of init_proposal_features to batch_size.
  47. Args:
  48. imgs (list[Tensor]): List of FPN features.
  49. img_metas (list[dict]): List of meta-information of
  50. images. Need the img_shape to decode the init_proposals.
  51. Returns:
  52. Tuple(Tensor):
  53. - proposals (Tensor): Decoded proposal bboxes,
  54. has shape (batch_size, num_proposals, 4).
  55. - init_proposal_features (Tensor): Expanded proposal
  56. features, has shape
  57. (batch_size, num_proposals, proposal_feature_channel).
  58. - imgs_whwh (Tensor): Tensor with shape
  59. (batch_size, 4), the dimension means
  60. [img_width, img_height, img_width, img_height].
  61. """
  62. proposals = self.init_proposal_bboxes.weight.clone()
  63. proposals = bbox_cxcywh_to_xyxy(proposals)
  64. num_imgs = len(imgs[0])
  65. imgs_whwh = []
  66. for meta in img_metas:
  67. h, w, _ = meta['img_shape']
  68. imgs_whwh.append(imgs[0].new_tensor([[w, h, w, h]]))
  69. imgs_whwh = torch.cat(imgs_whwh, dim=0)
  70. imgs_whwh = imgs_whwh[:, None, :]
  71. # imgs_whwh has shape (batch_size, 1, 4)
  72. # The shape of proposals change from (num_proposals, 4)
  73. # to (batch_size ,num_proposals, 4)
  74. proposals = proposals * imgs_whwh
  75. init_proposal_features = self.init_proposal_features.weight.clone()
  76. init_proposal_features = init_proposal_features[None].expand(
  77. num_imgs, *init_proposal_features.size())
  78. return proposals, init_proposal_features, imgs_whwh
  79. def forward_dummy(self, img, img_metas):
  80. """Dummy forward function.
  81. Used in flops calculation.
  82. """
  83. return self._decode_init_proposals(img, img_metas)
  84. def forward_train(self, img, img_metas):
  85. """Forward function in training stage."""
  86. return self._decode_init_proposals(img, img_metas)
  87. def simple_test_rpn(self, img, img_metas):
  88. """Forward function in testing stage."""
  89. return self._decode_init_proposals(img, img_metas)
  90. def simple_test(self, img, img_metas):
  91. """Forward function in testing stage."""
  92. raise NotImplementedError
  93. def aug_test_rpn(self, feats, img_metas):
  94. raise NotImplementedError(
  95. 'EmbeddingRPNHead does not support test-time augmentation')

No Description

Contributors (2)