
sparse_roi_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmdet.core import bbox2result, bbox2roi, bbox_xyxy_to_cxcywh
from mmdet.core.bbox.samplers import PseudoSampler
from ..builder import HEADS
from .cascade_roi_head import CascadeRoIHead


@HEADS.register_module()
class SparseRoIHead(CascadeRoIHead):
    r"""The RoIHead for `Sparse R-CNN: End-to-End Object Detection with
    Learnable Proposals <https://arxiv.org/abs/2011.12450>`_
    and `Instances as Queries <http://arxiv.org/abs/2105.01928>`_

    Args:
        num_stages (int): Number of stages in the iterative refinement
            process. Defaults to 6.
        stage_loss_weights (Tuple[float]): The loss weight of each stage.
            By default all stages have the same weight 1.
        proposal_feature_channel (int): Channel number of the per-proposal
            features. Defaults to 256.
        bbox_roi_extractor (dict): Config of box roi extractor.
        mask_roi_extractor (dict): Config of mask roi extractor.
        bbox_head (dict): Config of box head.
        mask_head (dict): Config of mask head.
        train_cfg (dict, optional): Configuration information in train stage.
            Defaults to None.
        test_cfg (dict, optional): Configuration information in test stage.
            Defaults to None.
        pretrained (str, optional): Path to the pretrained model.
            Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """
    def __init__(self,
                 num_stages=6,
                 stage_loss_weights=(1, 1, 1, 1, 1, 1),
                 proposal_feature_channel=256,
                 bbox_roi_extractor=dict(
                     type='SingleRoIExtractor',
                     roi_layer=dict(
                         type='RoIAlign', output_size=7, sampling_ratio=2),
                     out_channels=256,
                     featmap_strides=[4, 8, 16, 32]),
                 mask_roi_extractor=None,
                 bbox_head=dict(
                     type='DIIHead',
                     num_classes=80,
                     num_fcs=2,
                     num_heads=8,
                     num_cls_fcs=1,
                     num_reg_fcs=3,
                     feedforward_channels=2048,
                     hidden_channels=256,
                     dropout=0.0,
                     roi_feat_size=7,
                     ffn_act_cfg=dict(type='ReLU', inplace=True)),
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        assert bbox_roi_extractor is not None
        assert bbox_head is not None
        assert len(stage_loss_weights) == num_stages
        self.num_stages = num_stages
        self.stage_loss_weights = stage_loss_weights
        self.proposal_feature_channel = proposal_feature_channel
        super(SparseRoIHead, self).__init__(
            num_stages,
            stage_loss_weights,
            bbox_roi_extractor=bbox_roi_extractor,
            mask_roi_extractor=mask_roi_extractor,
            bbox_head=bbox_head,
            mask_head=mask_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
        # train_cfg would be None when running test.py
        if train_cfg is not None:
            for stage in range(num_stages):
                assert isinstance(self.bbox_sampler[stage], PseudoSampler), \
                    'Sparse R-CNN and QueryInst only support `PseudoSampler`'
    def _bbox_forward(self, stage, x, rois, object_feats, img_metas):
        """Box head forward function used in both training and testing.

        Returns all regression and classification results and an
        intermediate feature.

        Args:
            stage (int): The index of the current stage in the iterative
                process.
            x (list[Tensor]): List of FPN features.
            rois (Tensor): RoIs of the whole batch, with shape
                (num_proposals, 5). The last dimension 5 represents
                (img_index, x1, y1, x2, y2).
            object_feats (Tensor): The object features extracted from
                the previous stage.
            img_metas (list[dict]): Meta information of images.

        Returns:
            dict[str, Tensor]: A dictionary of bbox head outputs,
                containing the following results:

                - cls_score (Tensor): The score of each class, with shape
                  (batch_size, num_proposals, num_classes) when using focal
                  loss or (batch_size, num_proposals, num_classes+1)
                  otherwise.
                - decode_bbox_pred (Tensor): The regression results, with
                  shape (batch_size * num_proposals, 4) after per-image
                  results are concatenated. The last dimension 4 represents
                  [tl_x, tl_y, br_x, br_y].
                - object_feats (Tensor): The object features extracted
                  from the current stage.
                - detach_cls_score_list (list[Tensor]): The detached
                  classification results; length is batch_size, and each
                  tensor has shape (num_proposals, num_classes).
                - detach_proposal_list (list[Tensor]): The detached
                  regression results; length is batch_size, and each
                  tensor has shape (num_proposals, 4). The last dimension
                  4 represents [tl_x, tl_y, br_x, br_y].
        """
        num_imgs = len(img_metas)
        bbox_roi_extractor = self.bbox_roi_extractor[stage]
        bbox_head = self.bbox_head[stage]
        bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
                                        rois)
        cls_score, bbox_pred, object_feats, attn_feats = bbox_head(
            bbox_feats, object_feats)
        proposal_list = self.bbox_head[stage].refine_bboxes(
            rois,
            rois.new_zeros(len(rois)),  # dummy arg
            bbox_pred.view(-1, bbox_pred.size(-1)),
            [rois.new_zeros(object_feats.size(1)) for _ in range(num_imgs)],
            img_metas)
        bbox_results = dict(
            cls_score=cls_score,
            decode_bbox_pred=torch.cat(proposal_list),
            object_feats=object_feats,
            attn_feats=attn_feats,
            # detach, then use in label assignment
            detach_cls_score_list=[
                cls_score[i].detach() for i in range(num_imgs)
            ],
            detach_proposal_list=[item.detach() for item in proposal_list])
        return bbox_results
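
    # Editorial shape sketch for `_bbox_forward` (assumes the default config
    # with num_proposals=100 learnable proposals per image and 256 feature
    # channels; exact sizes depend on the configured bbox head):
    #   rois:                  (num_imgs * 100, 5)
    #   object_feats (in/out): (num_imgs, 100, 256)
    #   cls_score:             (num_imgs, 100, num_classes)
    #   decode_bbox_pred:      (num_imgs * 100, 4) after torch.cat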
    def _mask_forward(self, stage, x, rois, attn_feats):
        """Mask head forward function used in both training and testing."""
        mask_roi_extractor = self.mask_roi_extractor[stage]
        mask_head = self.mask_head[stage]
        mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
                                        rois)
        # caffe_c4 models are no longer supported
        mask_pred = mask_head(mask_feats, attn_feats)
        mask_results = dict(mask_pred=mask_pred)
        return mask_results
    def _mask_forward_train(self, stage, x, attn_feats, sampling_results,
                            gt_masks, rcnn_train_cfg):
        """Run the forward function and calculate the loss for the mask head
        in training."""
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
        attn_feats = torch.cat([
            feats[res.pos_inds]
            for (feats, res) in zip(attn_feats, sampling_results)
        ])
        mask_results = self._mask_forward(stage, x, pos_rois, attn_feats)

        mask_targets = self.mask_head[stage].get_targets(
            sampling_results, gt_masks, rcnn_train_cfg)
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head[stage].loss(mask_results['mask_pred'],
                                               mask_targets, pos_labels)
        mask_results.update(loss_mask)
        return mask_results
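
    # Note (editorial): mask training runs only on positive (matched)
    # proposals. `attn_feats` is gathered with the same per-image `pos_inds`
    # used to build `pos_rois`, keeping the attention features aligned with
    # the RoIs fed to the mask head.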
    def forward_train(self,
                      x,
                      proposal_boxes,
                      proposal_features,
                      img_metas,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      imgs_whwh=None,
                      gt_masks=None):
        """Forward function in training stage.

        Args:
            x (list[Tensor]): List of multi-level image features.
            proposal_boxes (Tensor): Decoded proposal bboxes, with shape
                (batch_size, num_proposals, 4).
            proposal_features (Tensor): Expanded proposal features, with
                shape (batch_size, num_proposals, proposal_feature_channel).
            img_metas (list[dict]): List of image info dicts where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also
                contain 'filename', 'ori_shape', 'pad_shape', and
                'img_norm_cfg'. For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image,
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each
                box.
            gt_bboxes_ignore (None | list[Tensor]): Specifies which bounding
                boxes can be ignored when computing the loss.
            imgs_whwh (Tensor): Tensor with shape (batch_size, 1, 4), where
                the last dimension means
                [img_width, img_height, img_width, img_height].
            gt_masks (None | Tensor): True segmentation masks for each box,
                used if the architecture supports a segmentation task.

        Returns:
            dict[str, Tensor]: A dictionary of loss components of all
                stages.
        """
        num_imgs = len(img_metas)
        num_proposals = proposal_boxes.size(1)
        imgs_whwh = imgs_whwh.repeat(1, num_proposals, 1)
        all_stage_bbox_results = []
        proposal_list = [proposal_boxes[i] for i in range(len(proposal_boxes))]
        object_feats = proposal_features
        all_stage_loss = {}
        for stage in range(self.num_stages):
            rois = bbox2roi(proposal_list)
            bbox_results = self._bbox_forward(stage, x, rois, object_feats,
                                              img_metas)
            all_stage_bbox_results.append(bbox_results)
            if gt_bboxes_ignore is None:
                # TODO support ignore
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            cls_pred_list = bbox_results['detach_cls_score_list']
            proposal_list = bbox_results['detach_proposal_list']
            for i in range(num_imgs):
                normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(proposal_list[i] /
                                                          imgs_whwh[i])
                assign_result = self.bbox_assigner[stage].assign(
                    normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                    gt_labels[i], img_metas[i])
                sampling_result = self.bbox_sampler[stage].sample(
                    assign_result, proposal_list[i], gt_bboxes[i])
                sampling_results.append(sampling_result)
            bbox_targets = self.bbox_head[stage].get_targets(
                sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage],
                True)
            cls_score = bbox_results['cls_score']
            decode_bbox_pred = bbox_results['decode_bbox_pred']

            single_stage_loss = self.bbox_head[stage].loss(
                cls_score.view(-1, cls_score.size(-1)),
                decode_bbox_pred.view(-1, 4),
                *bbox_targets,
                imgs_whwh=imgs_whwh)

            if self.with_mask:
                mask_results = self._mask_forward_train(
                    stage, x, bbox_results['attn_feats'], sampling_results,
                    gt_masks, self.train_cfg[stage])
                single_stage_loss['loss_mask'] = mask_results['loss_mask']

            for key, value in single_stage_loss.items():
                all_stage_loss[f'stage{stage}_{key}'] = value * \
                    self.stage_loss_weights[stage]
            object_feats = bbox_results['object_feats']

        return all_stage_loss
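
    # Editorial note: the returned dict uses `stage{i}_{loss_name}` keys,
    # e.g. entries along the lines of `stage0_loss_cls`, `stage0_loss_bbox`,
    # ..., plus `stage{i}_loss_mask` when a mask head is configured. The
    # exact loss names depend on the configured bbox head.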
    def simple_test(self,
                    x,
                    proposal_boxes,
                    proposal_features,
                    img_metas,
                    imgs_whwh,
                    rescale=False):
        """Test without augmentation.

        Args:
            x (list[Tensor]): List of multi-level image features.
            proposal_boxes (Tensor): Decoded proposal bboxes, with shape
                (batch_size, num_proposals, 4).
            proposal_features (Tensor): Expanded proposal features, with
                shape (batch_size, num_proposals, proposal_feature_channel).
            img_metas (list[dict]): Meta information of images.
            imgs_whwh (Tensor): Tensor with shape (batch_size, 1, 4), where
                the last dimension means
                [img_width, img_height, img_width, img_height].
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[list[np.ndarray]] or list[tuple]: When the model has no
                mask branch, it is the bbox results of each image and class
                with type `list[list[np.ndarray]]`. The outer list
                corresponds to each image; the inner list corresponds to
                each class. When the model has a mask branch, it is a
                list[tuple] that contains bbox results and mask results.
                The outer list corresponds to each image; the first element
                of each tuple is the bbox results and the second element is
                the mask results.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        # Decode initial proposals
        num_imgs = len(img_metas)
        proposal_list = [proposal_boxes[i] for i in range(num_imgs)]
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        object_feats = proposal_features
        if all([proposal.shape[0] == 0 for proposal in proposal_list]):
            # There is no proposal in the whole batch
            bbox_results = [[
                np.zeros((0, 5), dtype=np.float32)
                for i in range(self.bbox_head[-1].num_classes)
            ]] * num_imgs
            return bbox_results

        for stage in range(self.num_stages):
            rois = bbox2roi(proposal_list)
            bbox_results = self._bbox_forward(stage, x, rois, object_feats,
                                              img_metas)
            object_feats = bbox_results['object_feats']
            cls_score = bbox_results['cls_score']
            proposal_list = bbox_results['detach_proposal_list']

        if self.with_mask:
            rois = bbox2roi(proposal_list)
            mask_results = self._mask_forward(stage, x, rois,
                                              bbox_results['attn_feats'])
            mask_results['mask_pred'] = mask_results['mask_pred'].reshape(
                num_imgs, -1, *mask_results['mask_pred'].size()[1:])

        num_classes = self.bbox_head[-1].num_classes
        det_bboxes = []
        det_labels = []
        topk_indices_list = []

        if self.bbox_head[-1].loss_cls.use_sigmoid:
            cls_score = cls_score.sigmoid()
        else:
            cls_score = cls_score.softmax(-1)[..., :-1]

        for img_id in range(num_imgs):
            cls_score_per_img = cls_score[img_id]
            scores_per_img, topk_indices = cls_score_per_img.flatten(
                0, 1).topk(
                    self.test_cfg.max_per_img, sorted=False)
            # keep the per-image top-k indices so the mask branch below can
            # select the matching mask predictions for each image
            topk_indices_list.append(topk_indices)
            labels_per_img = topk_indices % num_classes
            bbox_pred_per_img = proposal_list[img_id][topk_indices //
                                                      num_classes]
            if rescale:
                scale_factor = img_metas[img_id]['scale_factor']
                bbox_pred_per_img /= bbox_pred_per_img.new_tensor(
                    scale_factor)
            det_bboxes.append(
                torch.cat([bbox_pred_per_img, scores_per_img[:, None]],
                          dim=1))
            det_labels.append(labels_per_img)

        bbox_results = [
            bbox2result(det_bboxes[i], det_labels[i], num_classes)
            for i in range(num_imgs)
        ]

        if self.with_mask:
            if rescale and not isinstance(scale_factors[0], float):
                scale_factors = [
                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                    for scale_factor in scale_factors
                ]
            _bboxes = [
                det_bboxes[i][:, :4] *
                scale_factors[i] if rescale else det_bboxes[i][:, :4]
                for i in range(len(det_bboxes))
            ]
            segm_results = []
            mask_pred = mask_results['mask_pred']
            for img_id in range(num_imgs):
                # use the indices saved for this image rather than the
                # leftover `topk_indices` from the last bbox loop iteration
                mask_pred_per_img = mask_pred[img_id].flatten(
                    0, 1)[topk_indices_list[img_id]]
                mask_pred_per_img = mask_pred_per_img[:, None, ...].repeat(
                    1, num_classes, 1, 1)
                segm_result = self.mask_head[-1].get_seg_masks(
                    mask_pred_per_img, _bboxes[img_id], det_labels[img_id],
                    self.test_cfg, ori_shapes[img_id], scale_factors[img_id],
                    rescale)
                segm_results.append(segm_result)

        if self.with_mask:
            results = list(zip(bbox_results, segm_results))
        else:
            results = bbox_results

        return results
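
    # Editorial illustration of the returned structure: for a bbox-only
    # model with 80 classes, `results[img_id][cls_id]` is an (n, 5) float32
    # array of [tl_x, tl_y, br_x, br_y, score]; with a mask head,
    # `results[img_id]` is a (bbox_results, segm_results) tuple instead.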
    def aug_test(self, features, proposal_list, img_metas, rescale=False):
        raise NotImplementedError(
            'Sparse R-CNN and QueryInst do not support `aug_test`')
    def forward_dummy(self, x, proposal_boxes, proposal_features, img_metas):
        """Dummy forward function used for computing FLOPs."""
        all_stage_bbox_results = []
        proposal_list = [proposal_boxes[i] for i in range(len(proposal_boxes))]
        object_feats = proposal_features
        if self.with_bbox:
            for stage in range(self.num_stages):
                rois = bbox2roi(proposal_list)
                bbox_results = self._bbox_forward(stage, x, rois,
                                                  object_feats, img_metas)
                all_stage_bbox_results.append((bbox_results, ))
                proposal_list = bbox_results['detach_proposal_list']
                object_feats = bbox_results['object_feats']
                if self.with_mask:
                    rois = bbox2roi(proposal_list)
                    mask_results = self._mask_forward(
                        stage, x, rois, bbox_results['attn_feats'])
                    all_stage_bbox_results[-1] += (mask_results, )
        return all_stage_bbox_results
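

# ---------------------------------------------------------------------------
# Usage sketch (editorial addition, not part of the upstream file): how this
# head typically appears inside an mmdet config dict. The values mirror the
# defaults declared above; detector-level fields such as the backbone, neck,
# and RPN head are omitted. Assumes mmdet 2.x.
# ---------------------------------------------------------------------------
example_roi_head_cfg = dict(
    type='SparseRoIHead',
    num_stages=6,
    stage_loss_weights=(1, 1, 1, 1, 1, 1),
    proposal_feature_channel=256,
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=2),
        out_channels=256,
        featmap_strides=[4, 8, 16, 32]),
    bbox_head=dict(
        type='DIIHead',
        num_classes=80,
        feedforward_channels=2048,
        hidden_channels=256,
        roi_feat_size=7))
# Building it standalone would then be, e.g.:
#   head = HEADS.build(example_roi_head_cfg)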
