You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import mmcv
  4. import torch
  5. from mmcv.image import tensor2imgs
  6. from mmdet.core import bbox_mapping
  7. from ..builder import DETECTORS, build_backbone, build_head, build_neck
  8. from .base import BaseDetector
  9. @DETECTORS.register_module()
  10. class RPN(BaseDetector):
  11. """Implementation of Region Proposal Network."""
  12. def __init__(self,
  13. backbone,
  14. neck,
  15. rpn_head,
  16. train_cfg,
  17. test_cfg,
  18. pretrained=None,
  19. init_cfg=None):
  20. super(RPN, self).__init__(init_cfg)
  21. if pretrained:
  22. warnings.warn('DeprecationWarning: pretrained is deprecated, '
  23. 'please use "init_cfg" instead')
  24. backbone.pretrained = pretrained
  25. self.backbone = build_backbone(backbone)
  26. self.neck = build_neck(neck) if neck is not None else None
  27. rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
  28. rpn_head.update(train_cfg=rpn_train_cfg)
  29. rpn_head.update(test_cfg=test_cfg.rpn)
  30. self.rpn_head = build_head(rpn_head)
  31. self.train_cfg = train_cfg
  32. self.test_cfg = test_cfg
  33. def extract_feat(self, img):
  34. """Extract features.
  35. Args:
  36. img (torch.Tensor): Image tensor with shape (n, c, h ,w).
  37. Returns:
  38. list[torch.Tensor]: Multi-level features that may have
  39. different resolutions.
  40. """
  41. x = self.backbone(img)
  42. if self.with_neck:
  43. x = self.neck(x)
  44. return x
  45. def forward_dummy(self, img):
  46. """Dummy forward function."""
  47. x = self.extract_feat(img)
  48. rpn_outs = self.rpn_head(x)
  49. return rpn_outs
  50. def forward_train(self,
  51. img,
  52. img_metas,
  53. gt_bboxes=None,
  54. gt_bboxes_ignore=None):
  55. """
  56. Args:
  57. img (Tensor): Input images of shape (N, C, H, W).
  58. Typically these should be mean centered and std scaled.
  59. img_metas (list[dict]): A List of image info dict where each dict
  60. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  61. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  62. For details on the values of these keys see
  63. :class:`mmdet.datasets.pipelines.Collect`.
  64. gt_bboxes (list[Tensor]): Each item are the truth boxes for each
  65. image in [tl_x, tl_y, br_x, br_y] format.
  66. gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
  67. boxes can be ignored when computing the loss.
  68. Returns:
  69. dict[str, Tensor]: A dictionary of loss components.
  70. """
  71. if (isinstance(self.train_cfg.rpn, dict)
  72. and self.train_cfg.rpn.get('debug', False)):
  73. self.rpn_head.debug_imgs = tensor2imgs(img)
  74. x = self.extract_feat(img)
  75. losses = self.rpn_head.forward_train(x, img_metas, gt_bboxes, None,
  76. gt_bboxes_ignore)
  77. return losses
  78. def simple_test(self, img, img_metas, rescale=False):
  79. """Test function without test time augmentation.
  80. Args:
  81. imgs (list[torch.Tensor]): List of multiple images
  82. img_metas (list[dict]): List of image information.
  83. rescale (bool, optional): Whether to rescale the results.
  84. Defaults to False.
  85. Returns:
  86. list[np.ndarray]: proposals
  87. """
  88. x = self.extract_feat(img)
  89. # get origin input shape to onnx dynamic input shape
  90. if torch.onnx.is_in_onnx_export():
  91. img_shape = torch._shape_as_tensor(img)[2:]
  92. img_metas[0]['img_shape_for_onnx'] = img_shape
  93. proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
  94. if rescale:
  95. for proposals, meta in zip(proposal_list, img_metas):
  96. proposals[:, :4] /= proposals.new_tensor(meta['scale_factor'])
  97. if torch.onnx.is_in_onnx_export():
  98. return proposal_list
  99. return [proposal.cpu().numpy() for proposal in proposal_list]
  100. def aug_test(self, imgs, img_metas, rescale=False):
  101. """Test function with test time augmentation.
  102. Args:
  103. imgs (list[torch.Tensor]): List of multiple images
  104. img_metas (list[dict]): List of image information.
  105. rescale (bool, optional): Whether to rescale the results.
  106. Defaults to False.
  107. Returns:
  108. list[np.ndarray]: proposals
  109. """
  110. proposal_list = self.rpn_head.aug_test_rpn(
  111. self.extract_feats(imgs), img_metas)
  112. if not rescale:
  113. for proposals, img_meta in zip(proposal_list, img_metas[0]):
  114. img_shape = img_meta['img_shape']
  115. scale_factor = img_meta['scale_factor']
  116. flip = img_meta['flip']
  117. flip_direction = img_meta['flip_direction']
  118. proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
  119. scale_factor, flip,
  120. flip_direction)
  121. return [proposal.cpu().numpy() for proposal in proposal_list]
  122. def show_result(self, data, result, top_k=20, **kwargs):
  123. """Show RPN proposals on the image.
  124. Args:
  125. data (str or np.ndarray): Image filename or loaded image.
  126. result (Tensor or tuple): The results to draw over `img`
  127. bbox_result or (bbox_result, segm_result).
  128. top_k (int): Plot the first k bboxes only
  129. if set positive. Default: 20
  130. Returns:
  131. np.ndarray: The image with bboxes drawn on it.
  132. """
  133. if kwargs is not None:
  134. kwargs.pop('score_thr', None)
  135. kwargs.pop('text_color', None)
  136. kwargs['colors'] = kwargs.pop('bbox_color', 'green')
  137. mmcv.imshow_bboxes(data, result, top_k=top_k, **kwargs)

No Description

Contributors (3)