You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grid_roi_head.py 7.0 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. from mmdet.core import bbox2result, bbox2roi
  5. from ..builder import HEADS, build_head, build_roi_extractor
  6. from .standard_roi_head import StandardRoIHead
  7. @HEADS.register_module()
  8. class GridRoIHead(StandardRoIHead):
  9. """Grid roi head for Grid R-CNN.
  10. https://arxiv.org/abs/1811.12030
  11. """
  12. def __init__(self, grid_roi_extractor, grid_head, **kwargs):
  13. assert grid_head is not None
  14. super(GridRoIHead, self).__init__(**kwargs)
  15. if grid_roi_extractor is not None:
  16. self.grid_roi_extractor = build_roi_extractor(grid_roi_extractor)
  17. self.share_roi_extractor = False
  18. else:
  19. self.share_roi_extractor = True
  20. self.grid_roi_extractor = self.bbox_roi_extractor
  21. self.grid_head = build_head(grid_head)
  22. def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
  23. """Ramdom jitter positive proposals for training."""
  24. for sampling_result, img_meta in zip(sampling_results, img_metas):
  25. bboxes = sampling_result.pos_bboxes
  26. random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
  27. -amplitude, amplitude)
  28. # before jittering
  29. cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
  30. wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
  31. # after jittering
  32. new_cxcy = cxcy + wh * random_offsets[:, :2]
  33. new_wh = wh * (1 + random_offsets[:, 2:])
  34. # xywh to xyxy
  35. new_x1y1 = (new_cxcy - new_wh / 2)
  36. new_x2y2 = (new_cxcy + new_wh / 2)
  37. new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
  38. # clip bboxes
  39. max_shape = img_meta['img_shape']
  40. if max_shape is not None:
  41. new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
  42. new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)
  43. sampling_result.pos_bboxes = new_bboxes
  44. return sampling_results
  45. def forward_dummy(self, x, proposals):
  46. """Dummy forward function."""
  47. # bbox head
  48. outs = ()
  49. rois = bbox2roi([proposals])
  50. if self.with_bbox:
  51. bbox_results = self._bbox_forward(x, rois)
  52. outs = outs + (bbox_results['cls_score'],
  53. bbox_results['bbox_pred'])
  54. # grid head
  55. grid_rois = rois[:100]
  56. grid_feats = self.grid_roi_extractor(
  57. x[:self.grid_roi_extractor.num_inputs], grid_rois)
  58. if self.with_shared_head:
  59. grid_feats = self.shared_head(grid_feats)
  60. grid_pred = self.grid_head(grid_feats)
  61. outs = outs + (grid_pred, )
  62. # mask head
  63. if self.with_mask:
  64. mask_rois = rois[:100]
  65. mask_results = self._mask_forward(x, mask_rois)
  66. outs = outs + (mask_results['mask_pred'], )
  67. return outs
  68. def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
  69. img_metas):
  70. """Run forward function and calculate loss for box head in training."""
  71. bbox_results = super(GridRoIHead,
  72. self)._bbox_forward_train(x, sampling_results,
  73. gt_bboxes, gt_labels,
  74. img_metas)
  75. # Grid head forward and loss
  76. sampling_results = self._random_jitter(sampling_results, img_metas)
  77. pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
  78. # GN in head does not support zero shape input
  79. if pos_rois.shape[0] == 0:
  80. return bbox_results
  81. grid_feats = self.grid_roi_extractor(
  82. x[:self.grid_roi_extractor.num_inputs], pos_rois)
  83. if self.with_shared_head:
  84. grid_feats = self.shared_head(grid_feats)
  85. # Accelerate training
  86. max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
  87. sample_idx = torch.randperm(
  88. grid_feats.shape[0])[:min(grid_feats.shape[0], max_sample_num_grid
  89. )]
  90. grid_feats = grid_feats[sample_idx]
  91. grid_pred = self.grid_head(grid_feats)
  92. grid_targets = self.grid_head.get_targets(sampling_results,
  93. self.train_cfg)
  94. grid_targets = grid_targets[sample_idx]
  95. loss_grid = self.grid_head.loss(grid_pred, grid_targets)
  96. bbox_results['loss_bbox'].update(loss_grid)
  97. return bbox_results
  98. def simple_test(self,
  99. x,
  100. proposal_list,
  101. img_metas,
  102. proposals=None,
  103. rescale=False):
  104. """Test without augmentation."""
  105. assert self.with_bbox, 'Bbox head must be implemented.'
  106. det_bboxes, det_labels = self.simple_test_bboxes(
  107. x, img_metas, proposal_list, self.test_cfg, rescale=False)
  108. # pack rois into bboxes
  109. grid_rois = bbox2roi([det_bbox[:, :4] for det_bbox in det_bboxes])
  110. if grid_rois.shape[0] != 0:
  111. grid_feats = self.grid_roi_extractor(
  112. x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
  113. self.grid_head.test_mode = True
  114. grid_pred = self.grid_head(grid_feats)
  115. # split batch grid head prediction back to each image
  116. num_roi_per_img = tuple(len(det_bbox) for det_bbox in det_bboxes)
  117. grid_pred = {
  118. k: v.split(num_roi_per_img, 0)
  119. for k, v in grid_pred.items()
  120. }
  121. # apply bbox post-processing to each image individually
  122. bbox_results = []
  123. num_imgs = len(det_bboxes)
  124. for i in range(num_imgs):
  125. if det_bboxes[i].shape[0] == 0:
  126. bbox_results.append([
  127. np.zeros((0, 5), dtype=np.float32)
  128. for _ in range(self.bbox_head.num_classes)
  129. ])
  130. else:
  131. det_bbox = self.grid_head.get_bboxes(
  132. det_bboxes[i], grid_pred['fused'][i], [img_metas[i]])
  133. if rescale:
  134. det_bbox[:, :4] /= img_metas[i]['scale_factor']
  135. bbox_results.append(
  136. bbox2result(det_bbox, det_labels[i],
  137. self.bbox_head.num_classes))
  138. else:
  139. bbox_results = [[
  140. np.zeros((0, 5), dtype=np.float32)
  141. for _ in range(self.bbox_head.num_classes)
  142. ] for _ in range(len(det_bboxes))]
  143. if not self.with_mask:
  144. return bbox_results
  145. else:
  146. segm_results = self.simple_test_mask(
  147. x, img_metas, det_bboxes, det_labels, rescale=rescale)
  148. return list(zip(bbox_results, segm_results))

No Description

Contributors (3)