You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

standard_roi_head.py 17 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
  4. from ..builder import HEADS, build_head, build_roi_extractor
  5. from .base_roi_head import BaseRoIHead
  6. from .test_mixins import BBoxTestMixin, MaskTestMixin
  7. @HEADS.register_module()
  8. class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
  9. """Simplest base roi head including one bbox head and one mask head."""
  10. def init_assigner_sampler(self):
  11. """Initialize assigner and sampler."""
  12. self.bbox_assigner = None
  13. self.bbox_sampler = None
  14. if self.train_cfg:
  15. self.bbox_assigner = build_assigner(self.train_cfg.assigner)
  16. self.bbox_sampler = build_sampler(
  17. self.train_cfg.sampler, context=self)
  18. def init_bbox_head(self, bbox_roi_extractor, bbox_head):
  19. """Initialize ``bbox_head``"""
  20. self.bbox_roi_extractor = build_roi_extractor(bbox_roi_extractor)
  21. self.bbox_head = build_head(bbox_head)
  22. def init_mask_head(self, mask_roi_extractor, mask_head):
  23. """Initialize ``mask_head``"""
  24. if mask_roi_extractor is not None:
  25. self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
  26. self.share_roi_extractor = False
  27. else:
  28. self.share_roi_extractor = True
  29. self.mask_roi_extractor = self.bbox_roi_extractor
  30. self.mask_head = build_head(mask_head)
  31. def forward_dummy(self, x, proposals):
  32. """Dummy forward function."""
  33. # bbox head
  34. outs = ()
  35. rois = bbox2roi([proposals])
  36. if self.with_bbox:
  37. bbox_results = self._bbox_forward(x, rois)
  38. outs = outs + (bbox_results['cls_score'],
  39. bbox_results['bbox_pred'])
  40. # mask head
  41. if self.with_mask:
  42. mask_rois = rois[:100]
  43. mask_results = self._mask_forward(x, mask_rois)
  44. outs = outs + (mask_results['mask_pred'], )
  45. return outs
  46. def forward_train(self,
  47. x,
  48. img_metas,
  49. proposal_list,
  50. gt_bboxes,
  51. gt_labels,
  52. gt_bboxes_ignore=None,
  53. gt_masks=None,
  54. **kwargs):
  55. """
  56. Args:
  57. x (list[Tensor]): list of multi-level img features.
  58. img_metas (list[dict]): list of image info dict where each dict
  59. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  60. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  61. For details on the values of these keys see
  62. `mmdet/datasets/pipelines/formatting.py:Collect`.
  63. proposals (list[Tensors]): list of region proposals.
  64. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
  65. shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
  66. gt_labels (list[Tensor]): class indices corresponding to each box
  67. gt_bboxes_ignore (None | list[Tensor]): specify which bounding
  68. boxes can be ignored when computing the loss.
  69. gt_masks (None | Tensor) : true segmentation masks for each box
  70. used if the architecture supports a segmentation task.
  71. Returns:
  72. dict[str, Tensor]: a dictionary of loss components
  73. """
  74. # assign gts and sample proposals
  75. if self.with_bbox or self.with_mask:
  76. num_imgs = len(img_metas)
  77. if gt_bboxes_ignore is None:
  78. gt_bboxes_ignore = [None for _ in range(num_imgs)]
  79. sampling_results = []
  80. for i in range(num_imgs):
  81. assign_result = self.bbox_assigner.assign(
  82. proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
  83. gt_labels[i])
  84. sampling_result = self.bbox_sampler.sample(
  85. assign_result,
  86. proposal_list[i],
  87. gt_bboxes[i],
  88. gt_labels[i],
  89. feats=[lvl_feat[i][None] for lvl_feat in x])
  90. sampling_results.append(sampling_result)
  91. losses = dict()
  92. # bbox head forward and loss
  93. if self.with_bbox:
  94. bbox_results = self._bbox_forward_train(x, sampling_results,
  95. gt_bboxes, gt_labels,
  96. img_metas)
  97. losses.update(bbox_results['loss_bbox'])
  98. # mask head forward and loss
  99. if self.with_mask:
  100. mask_results = self._mask_forward_train(x, sampling_results,
  101. bbox_results['bbox_feats'],
  102. gt_masks, img_metas)
  103. losses.update(mask_results['loss_mask'])
  104. return losses
  105. def _bbox_forward(self, x, rois):
  106. """Box head forward function used in both training and testing."""
  107. # TODO: a more flexible way to decide which feature maps to use
  108. bbox_feats = self.bbox_roi_extractor(
  109. x[:self.bbox_roi_extractor.num_inputs], rois)
  110. if self.with_shared_head:
  111. bbox_feats = self.shared_head(bbox_feats)
  112. cls_score, bbox_pred = self.bbox_head(bbox_feats)
  113. bbox_results = dict(
  114. cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
  115. return bbox_results
  116. def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
  117. img_metas):
  118. """Run forward function and calculate loss for box head in training."""
  119. rois = bbox2roi([res.bboxes for res in sampling_results])
  120. bbox_results = self._bbox_forward(x, rois)
  121. bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
  122. gt_labels, self.train_cfg)
  123. loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
  124. bbox_results['bbox_pred'], rois,
  125. *bbox_targets)
  126. bbox_results.update(loss_bbox=loss_bbox)
  127. return bbox_results
  128. def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
  129. img_metas):
  130. """Run forward function and calculate loss for mask head in
  131. training."""
  132. if not self.share_roi_extractor:
  133. pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
  134. mask_results = self._mask_forward(x, pos_rois)
  135. else:
  136. pos_inds = []
  137. device = bbox_feats.device
  138. for res in sampling_results:
  139. pos_inds.append(
  140. torch.ones(
  141. res.pos_bboxes.shape[0],
  142. device=device,
  143. dtype=torch.uint8))
  144. pos_inds.append(
  145. torch.zeros(
  146. res.neg_bboxes.shape[0],
  147. device=device,
  148. dtype=torch.uint8))
  149. pos_inds = torch.cat(pos_inds)
  150. mask_results = self._mask_forward(
  151. x, pos_inds=pos_inds, bbox_feats=bbox_feats)
  152. mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
  153. self.train_cfg)
  154. pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
  155. loss_mask = self.mask_head.loss(mask_results['mask_pred'],
  156. mask_targets, pos_labels)
  157. mask_results.update(loss_mask=loss_mask, mask_targets=mask_targets)
  158. return mask_results
  159. def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
  160. """Mask head forward function used in both training and testing."""
  161. assert ((rois is not None) ^
  162. (pos_inds is not None and bbox_feats is not None))
  163. if rois is not None:
  164. mask_feats = self.mask_roi_extractor(
  165. x[:self.mask_roi_extractor.num_inputs], rois)
  166. if self.with_shared_head:
  167. mask_feats = self.shared_head(mask_feats)
  168. else:
  169. assert bbox_feats is not None
  170. mask_feats = bbox_feats[pos_inds]
  171. mask_pred = self.mask_head(mask_feats)
  172. mask_results = dict(mask_pred=mask_pred, mask_feats=mask_feats)
  173. return mask_results
  174. async def async_simple_test(self,
  175. x,
  176. proposal_list,
  177. img_metas,
  178. proposals=None,
  179. rescale=False):
  180. """Async test without augmentation."""
  181. assert self.with_bbox, 'Bbox head must be implemented.'
  182. det_bboxes, det_labels = await self.async_test_bboxes(
  183. x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
  184. bbox_results = bbox2result(det_bboxes, det_labels,
  185. self.bbox_head.num_classes)
  186. if not self.with_mask:
  187. return bbox_results
  188. else:
  189. segm_results = await self.async_test_mask(
  190. x,
  191. img_metas,
  192. det_bboxes,
  193. det_labels,
  194. rescale=rescale,
  195. mask_test_cfg=self.test_cfg.get('mask'))
  196. return bbox_results, segm_results
  197. def simple_test(self,
  198. x,
  199. proposal_list,
  200. img_metas,
  201. proposals=None,
  202. rescale=False):
  203. """Test without augmentation.
  204. Args:
  205. x (tuple[Tensor]): Features from upstream network. Each
  206. has shape (batch_size, c, h, w).
  207. proposal_list (list(Tensor)): Proposals from rpn head.
  208. Each has shape (num_proposals, 5), last dimension
  209. 5 represent (x1, y1, x2, y2, score).
  210. img_metas (list[dict]): Meta information of images.
  211. rescale (bool): Whether to rescale the results to
  212. the original image. Default: True.
  213. Returns:
  214. list[list[np.ndarray]] or list[tuple]: When no mask branch,
  215. it is bbox results of each image and classes with type
  216. `list[list[np.ndarray]]`. The outer list
  217. corresponds to each image. The inner list
  218. corresponds to each class. When the model has mask branch,
  219. it contains bbox results and mask results.
  220. The outer list corresponds to each image, and first element
  221. of tuple is bbox results, second element is mask results.
  222. """
  223. assert self.with_bbox, 'Bbox head must be implemented.'
  224. det_bboxes, det_labels = self.simple_test_bboxes(
  225. x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
  226. bbox_results = [
  227. bbox2result(det_bboxes[i], det_labels[i],
  228. self.bbox_head.num_classes)
  229. for i in range(len(det_bboxes))
  230. ]
  231. if not self.with_mask:
  232. return bbox_results
  233. else:
  234. segm_results = self.simple_test_mask(
  235. x, img_metas, det_bboxes, det_labels, rescale=rescale)
  236. return list(zip(bbox_results, segm_results))
  237. def aug_test(self, x, proposal_list, img_metas, rescale=False):
  238. """Test with augmentations.
  239. If rescale is False, then returned bboxes and masks will fit the scale
  240. of imgs[0].
  241. """
  242. det_bboxes, det_labels = self.aug_test_bboxes(x, img_metas,
  243. proposal_list,
  244. self.test_cfg)
  245. if rescale:
  246. _det_bboxes = det_bboxes
  247. else:
  248. _det_bboxes = det_bboxes.clone()
  249. _det_bboxes[:, :4] *= det_bboxes.new_tensor(
  250. img_metas[0][0]['scale_factor'])
  251. bbox_results = bbox2result(_det_bboxes, det_labels,
  252. self.bbox_head.num_classes)
  253. # det_bboxes always keep the original scale
  254. if self.with_mask:
  255. segm_results = self.aug_test_mask(x, img_metas, det_bboxes,
  256. det_labels)
  257. return [(bbox_results, segm_results)]
  258. else:
  259. return [bbox_results]
  260. def onnx_export(self, x, proposals, img_metas, rescale=False):
  261. """Test without augmentation."""
  262. assert self.with_bbox, 'Bbox head must be implemented.'
  263. det_bboxes, det_labels = self.bbox_onnx_export(
  264. x, img_metas, proposals, self.test_cfg, rescale=rescale)
  265. if not self.with_mask:
  266. return det_bboxes, det_labels
  267. else:
  268. segm_results = self.mask_onnx_export(
  269. x, img_metas, det_bboxes, det_labels, rescale=rescale)
  270. return det_bboxes, det_labels, segm_results
  271. def mask_onnx_export(self, x, img_metas, det_bboxes, det_labels, **kwargs):
  272. """Export mask branch to onnx which supports batch inference.
  273. Args:
  274. x (tuple[Tensor]): Feature maps of all scale level.
  275. img_metas (list[dict]): Image meta info.
  276. det_bboxes (Tensor): Bboxes and corresponding scores.
  277. has shape [N, num_bboxes, 5].
  278. det_labels (Tensor): class labels of
  279. shape [N, num_bboxes].
  280. Returns:
  281. Tensor: The segmentation results of shape [N, num_bboxes,
  282. image_height, image_width].
  283. """
  284. # image shapes of images in the batch
  285. if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
  286. raise RuntimeError('[ONNX Error] Can not record MaskHead '
  287. 'as it has not been executed this time')
  288. batch_size = det_bboxes.size(0)
  289. # if det_bboxes is rescaled to the original image size, we need to
  290. # rescale it back to the testing scale to obtain RoIs.
  291. det_bboxes = det_bboxes[..., :4]
  292. batch_index = torch.arange(
  293. det_bboxes.size(0), device=det_bboxes.device).float().view(
  294. -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
  295. mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
  296. mask_rois = mask_rois.view(-1, 5)
  297. mask_results = self._mask_forward(x, mask_rois)
  298. mask_pred = mask_results['mask_pred']
  299. max_shape = img_metas[0]['img_shape_for_onnx']
  300. num_det = det_bboxes.shape[1]
  301. det_bboxes = det_bboxes.reshape(-1, 4)
  302. det_labels = det_labels.reshape(-1)
  303. segm_results = self.mask_head.onnx_export(mask_pred, det_bboxes,
  304. det_labels, self.test_cfg,
  305. max_shape)
  306. segm_results = segm_results.reshape(batch_size, num_det, max_shape[0],
  307. max_shape[1])
  308. return segm_results
  309. def bbox_onnx_export(self, x, img_metas, proposals, rcnn_test_cfg,
  310. **kwargs):
  311. """Export bbox branch to onnx which supports batch inference.
  312. Args:
  313. x (tuple[Tensor]): Feature maps of all scale level.
  314. img_metas (list[dict]): Image meta info.
  315. proposals (Tensor): Region proposals with
  316. batch dimension, has shape [N, num_bboxes, 5].
  317. rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
  318. Returns:
  319. tuple[Tensor, Tensor]: bboxes of shape [N, num_bboxes, 5]
  320. and class labels of shape [N, num_bboxes].
  321. """
  322. # get origin input shape to support onnx dynamic input shape
  323. assert len(
  324. img_metas
  325. ) == 1, 'Only support one input image while in exporting to ONNX'
  326. img_shapes = img_metas[0]['img_shape_for_onnx']
  327. rois = proposals
  328. batch_index = torch.arange(
  329. rois.size(0), device=rois.device).float().view(-1, 1, 1).expand(
  330. rois.size(0), rois.size(1), 1)
  331. rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
  332. batch_size = rois.shape[0]
  333. num_proposals_per_img = rois.shape[1]
  334. # Eliminate the batch dimension
  335. rois = rois.view(-1, 5)
  336. bbox_results = self._bbox_forward(x, rois)
  337. cls_score = bbox_results['cls_score']
  338. bbox_pred = bbox_results['bbox_pred']
  339. # Recover the batch dimension
  340. rois = rois.reshape(batch_size, num_proposals_per_img, rois.size(-1))
  341. cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
  342. cls_score.size(-1))
  343. bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img,
  344. bbox_pred.size(-1))
  345. det_bboxes, det_labels = self.bbox_head.onnx_export(
  346. rois, cls_score, bbox_pred, img_shapes, cfg=rcnn_test_cfg)
  347. return det_bboxes, det_labels

No Description

Contributors (2)