You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

htc_roi_head.py 28 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. import torch.nn.functional as F
  5. from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
  6. merge_aug_masks, multiclass_nms)
  7. from ..builder import HEADS, build_head, build_roi_extractor
  8. from ..utils.brick_wrappers import adaptive_avg_pool2d
  9. from .cascade_roi_head import CascadeRoIHead
@HEADS.register_module()
class HybridTaskCascadeRoIHead(CascadeRoIHead):
    """Hybrid task cascade roi head including one bbox head and one mask head.

    https://arxiv.org/abs/1901.07518

    Args:
        num_stages (int): Number of cascade stages, forwarded to
            ``CascadeRoIHead``.
        stage_loss_weights (list[float]): Per-stage loss weights, forwarded
            to ``CascadeRoIHead``.
        semantic_roi_extractor (dict, optional): Config for the RoI extractor
            applied to the semantic feature map; only built when
            ``semantic_head`` is given.
        semantic_head (dict, optional): Config of the semantic segmentation
            head. When ``None``, no semantic branch is created.
        semantic_fusion (tuple[str]): Which branches fuse the pooled semantic
            features ('bbox' and/or 'mask').
        interleaved (bool): Whether to re-sample RoIs from refined boxes
            before training the mask branch (see ``forward_train``).
        mask_info_flow (bool): Whether mask features flow from earlier mask
            heads to later ones.
    """

    def __init__(self,
                 num_stages,
                 stage_loss_weights,
                 semantic_roi_extractor=None,
                 semantic_head=None,
                 semantic_fusion=('bbox', 'mask'),
                 interleaved=True,
                 mask_info_flow=True,
                 **kwargs):
        super(HybridTaskCascadeRoIHead,
              self).__init__(num_stages, stage_loss_weights, **kwargs)
        # HTC requires a bbox branch; a shared head between branches is not
        # supported by this implementation.
        assert self.with_bbox
        assert not self.with_shared_head  # shared head is not supported

        # The semantic branch is optional; both sub-modules are only built
        # when a semantic head config is supplied.
        if semantic_head is not None:
            self.semantic_roi_extractor = build_roi_extractor(
                semantic_roi_extractor)
            self.semantic_head = build_head(semantic_head)

        self.semantic_fusion = semantic_fusion
        self.interleaved = interleaved
        self.mask_info_flow = mask_info_flow
  35. @property
  36. def with_semantic(self):
  37. """bool: whether the head has semantic head"""
  38. if hasattr(self, 'semantic_head') and self.semantic_head is not None:
  39. return True
  40. else:
  41. return False
  42. def forward_dummy(self, x, proposals):
  43. """Dummy forward function."""
  44. outs = ()
  45. # semantic head
  46. if self.with_semantic:
  47. _, semantic_feat = self.semantic_head(x)
  48. else:
  49. semantic_feat = None
  50. # bbox heads
  51. rois = bbox2roi([proposals])
  52. for i in range(self.num_stages):
  53. bbox_results = self._bbox_forward(
  54. i, x, rois, semantic_feat=semantic_feat)
  55. outs = outs + (bbox_results['cls_score'],
  56. bbox_results['bbox_pred'])
  57. # mask heads
  58. if self.with_mask:
  59. mask_rois = rois[:100]
  60. mask_roi_extractor = self.mask_roi_extractor[-1]
  61. mask_feats = mask_roi_extractor(
  62. x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
  63. if self.with_semantic and 'mask' in self.semantic_fusion:
  64. mask_semantic_feat = self.semantic_roi_extractor(
  65. [semantic_feat], mask_rois)
  66. mask_feats += mask_semantic_feat
  67. last_feat = None
  68. for i in range(self.num_stages):
  69. mask_head = self.mask_head[i]
  70. if self.mask_info_flow:
  71. mask_pred, last_feat = mask_head(mask_feats, last_feat)
  72. else:
  73. mask_pred = mask_head(mask_feats)
  74. outs = outs + (mask_pred, )
  75. return outs
  76. def _bbox_forward_train(self,
  77. stage,
  78. x,
  79. sampling_results,
  80. gt_bboxes,
  81. gt_labels,
  82. rcnn_train_cfg,
  83. semantic_feat=None):
  84. """Run forward function and calculate loss for box head in training."""
  85. bbox_head = self.bbox_head[stage]
  86. rois = bbox2roi([res.bboxes for res in sampling_results])
  87. bbox_results = self._bbox_forward(
  88. stage, x, rois, semantic_feat=semantic_feat)
  89. bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes,
  90. gt_labels, rcnn_train_cfg)
  91. loss_bbox = bbox_head.loss(bbox_results['cls_score'],
  92. bbox_results['bbox_pred'], rois,
  93. *bbox_targets)
  94. bbox_results.update(
  95. loss_bbox=loss_bbox,
  96. rois=rois,
  97. bbox_targets=bbox_targets,
  98. )
  99. return bbox_results
    def _mask_forward_train(self,
                            stage,
                            x,
                            sampling_results,
                            gt_masks,
                            rcnn_train_cfg,
                            semantic_feat=None):
        """Run forward function and calculate loss for mask head in
        training."""
        mask_roi_extractor = self.mask_roi_extractor[stage]
        mask_head = self.mask_head[stage]
        # Only positive samples contribute to mask training.
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
        mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
                                        pos_rois)

        # semantic feature fusion
        # element-wise sum for original features and pooled semantic features
        if self.with_semantic and 'mask' in self.semantic_fusion:
            mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
                                                             pos_rois)
            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
                # Resize pooled semantic features to match the mask features.
                mask_semantic_feat = F.adaptive_avg_pool2d(
                    mask_semantic_feat, mask_feats.shape[-2:])
            mask_feats += mask_semantic_feat

        # mask information flow
        # forward all previous mask heads to obtain last_feat, and fuse it
        # with the normal mask feature
        if self.mask_info_flow:
            last_feat = None
            for i in range(stage):
                # Earlier heads only produce the flowed feature (no logits).
                last_feat = self.mask_head[i](
                    mask_feats, last_feat, return_logits=False)
            # The current head returns only the prediction (no flowed feat).
            mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
        else:
            mask_pred = mask_head(mask_feats, return_feat=False)

        mask_targets = mask_head.get_targets(sampling_results, gt_masks,
                                             rcnn_train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)

        mask_results = dict(loss_mask=loss_mask)
        return mask_results
  140. def _bbox_forward(self, stage, x, rois, semantic_feat=None):
  141. """Box head forward function used in both training and testing."""
  142. bbox_roi_extractor = self.bbox_roi_extractor[stage]
  143. bbox_head = self.bbox_head[stage]
  144. bbox_feats = bbox_roi_extractor(
  145. x[:len(bbox_roi_extractor.featmap_strides)], rois)
  146. if self.with_semantic and 'bbox' in self.semantic_fusion:
  147. bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
  148. rois)
  149. if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
  150. bbox_semantic_feat = adaptive_avg_pool2d(
  151. bbox_semantic_feat, bbox_feats.shape[-2:])
  152. bbox_feats += bbox_semantic_feat
  153. cls_score, bbox_pred = bbox_head(bbox_feats)
  154. bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
  155. return bbox_results
    def _mask_forward_test(self, stage, x, bboxes, semantic_feat=None):
        """Mask head forward function for testing."""
        mask_roi_extractor = self.mask_roi_extractor[stage]
        mask_head = self.mask_head[stage]
        mask_rois = bbox2roi([bboxes])
        mask_feats = mask_roi_extractor(
            x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
        # Fuse pooled semantic features (element-wise sum), resizing them
        # first when spatial sizes differ.
        if self.with_semantic and 'mask' in self.semantic_fusion:
            mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
                                                             mask_rois)
            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
                mask_semantic_feat = F.adaptive_avg_pool2d(
                    mask_semantic_feat, mask_feats.shape[-2:])
            mask_feats += mask_semantic_feat
        if self.mask_info_flow:
            last_feat = None
            last_pred = None
            # Run every previous stage: each passes its flowed feature on and
            # its prediction is summed into a running total (last_pred).
            for i in range(stage):
                mask_pred, last_feat = self.mask_head[i](mask_feats, last_feat)
                if last_pred is not None:
                    mask_pred = mask_pred + last_pred
                last_pred = mask_pred
            # Current stage consumes the flowed feature and its logits are
            # added to the accumulated predictions of earlier stages.
            mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
            if last_pred is not None:
                mask_pred = mask_pred + last_pred
        else:
            mask_pred = mask_head(mask_feats)
        return mask_pred
    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      gt_semantic_seg=None):
        """
        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensors]): list of region proposals.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            gt_bboxes_ignore (None, list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None, Tensor) : true segmentation masks for each box
                used if the architecture supports a segmentation task.
            gt_semantic_seg (None, list[Tensor]): semantic segmentation masks
                used if the architecture supports semantic segmentation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # semantic segmentation part
        # 2 outputs: segmentation prediction and embedded features
        losses = dict()
        if self.with_semantic:
            semantic_pred, semantic_feat = self.semantic_head(x)
            loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
            losses['loss_semantic_seg'] = loss_seg
        else:
            semantic_feat = None

        for i in range(self.num_stages):
            self.current_stage = i
            rcnn_train_cfg = self.train_cfg[i]
            lw = self.stage_loss_weights[i]

            # assign gts and sample proposals
            sampling_results = []
            bbox_assigner = self.bbox_assigner[i]
            bbox_sampler = self.bbox_sampler[i]
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]

            for j in range(num_imgs):
                assign_result = bbox_assigner.assign(proposal_list[j],
                                                     gt_bboxes[j],
                                                     gt_bboxes_ignore[j],
                                                     gt_labels[j])
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[j],
                    gt_bboxes[j],
                    gt_labels[j],
                    feats=[lvl_feat[j][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

            # bbox head forward and loss
            bbox_results = \
                self._bbox_forward_train(
                    i, x, sampling_results, gt_bboxes, gt_labels,
                    rcnn_train_cfg, semantic_feat)
            roi_labels = bbox_results['bbox_targets'][0]

            # Stage losses are weighted by the per-stage loss weight; other
            # logged values (e.g. accuracies) are passed through unscaled.
            for name, value in bbox_results['loss_bbox'].items():
                losses[f's{i}.{name}'] = (
                    value * lw if 'loss' in name else value)

            # mask head forward and loss
            if self.with_mask:
                # interleaved execution: use regressed bboxes by the box branch
                # to train the mask branch
                if self.interleaved:
                    pos_is_gts = [res.pos_is_gt for res in sampling_results]
                    # Box refinement and re-sampling must not contribute
                    # gradients to the bbox branch.
                    with torch.no_grad():
                        proposal_list = self.bbox_head[i].refine_bboxes(
                            bbox_results['rois'], roi_labels,
                            bbox_results['bbox_pred'], pos_is_gts, img_metas)
                        # re-assign and sample 512 RoIs from 512 RoIs
                        sampling_results = []
                        for j in range(num_imgs):
                            assign_result = bbox_assigner.assign(
                                proposal_list[j], gt_bboxes[j],
                                gt_bboxes_ignore[j], gt_labels[j])
                            sampling_result = bbox_sampler.sample(
                                assign_result,
                                proposal_list[j],
                                gt_bboxes[j],
                                gt_labels[j],
                                feats=[lvl_feat[j][None] for lvl_feat in x])
                            sampling_results.append(sampling_result)
                mask_results = self._mask_forward_train(
                    i, x, sampling_results, gt_masks, rcnn_train_cfg,
                    semantic_feat)
                for name, value in mask_results['loss_mask'].items():
                    losses[f's{i}.{name}'] = (
                        value * lw if 'loss' in name else value)

            # refine bboxes (same as Cascade R-CNN)
            # In interleaved mode the refinement already happened above.
            if i < self.num_stages - 1 and not self.interleaved:
                pos_is_gts = [res.pos_is_gt for res in sampling_results]
                with torch.no_grad():
                    proposal_list = self.bbox_head[i].refine_bboxes(
                        bbox_results['rois'], roi_labels,
                        bbox_results['bbox_pred'], pos_is_gts, img_metas)

        return losses
    def simple_test(self, x, proposal_list, img_metas, rescale=False):
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Features from upstream network. Each
                has shape (batch_size, c, h, w).
            proposal_list (list(Tensor)): Proposals from rpn head.
                Each has shape (num_proposals, 5), last dimension
                5 represent (x1, y1, x2, y2, score).
            img_metas (list[dict]): Meta information of images.
            rescale (bool): Whether to rescale the results to
                the original image. Default: True.

        Returns:
            list[list[np.ndarray]] or list[tuple]: When no mask branch,
            it is bbox results of each image and classes with type
            `list[list[np.ndarray]]`. The outer list
            corresponds to each image. The inner list
            corresponds to each class. When the model has mask branch,
            it contains bbox results and mask results.
            The outer list corresponds to each image, and first element
            of tuple is bbox results, second element is mask results.
        """
        if self.with_semantic:
            _, semantic_feat = self.semantic_head(x)
        else:
            semantic_feat = None

        num_imgs = len(proposal_list)
        img_shapes = tuple(meta['img_shape'] for meta in img_metas)
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # "ms" in variable names means multi-stage
        ms_bbox_result = {}
        ms_segm_result = {}
        ms_scores = []
        rcnn_test_cfg = self.test_cfg

        rois = bbox2roi(proposal_list)

        if rois.shape[0] == 0:
            # There is no proposal in the whole batch
            bbox_results = [[
                np.zeros((0, 5), dtype=np.float32)
                for _ in range(self.bbox_head[-1].num_classes)
            ]] * num_imgs

            if self.with_mask:
                mask_classes = self.mask_head[-1].num_classes
                segm_results = [[[] for _ in range(mask_classes)]
                                for _ in range(num_imgs)]
                results = list(zip(bbox_results, segm_results))
            else:
                results = bbox_results
            return results

        for i in range(self.num_stages):
            bbox_head = self.bbox_head[i]
            bbox_results = self._bbox_forward(
                i, x, rois, semantic_feat=semantic_feat)
            # split batch bbox prediction back to each image
            cls_score = bbox_results['cls_score']
            bbox_pred = bbox_results['bbox_pred']
            num_proposals_per_img = tuple(len(p) for p in proposal_list)
            rois = rois.split(num_proposals_per_img, 0)
            cls_score = cls_score.split(num_proposals_per_img, 0)
            bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
            ms_scores.append(cls_score)

            if i < self.num_stages - 1:
                refine_rois_list = []
                for j in range(num_imgs):
                    if rois[j].shape[0] > 0:
                        # Exclude the background column (last) when picking
                        # the label used for class-wise box refinement.
                        bbox_label = cls_score[j][:, :-1].argmax(dim=1)
                        refine_rois = bbox_head.regress_by_class(
                            rois[j], bbox_label, bbox_pred[j], img_metas[j])
                        refine_rois_list.append(refine_rois)
                rois = torch.cat(refine_rois_list)

        # average scores of each image by stages
        cls_score = [
            sum([score[i] for score in ms_scores]) / float(len(ms_scores))
            for i in range(num_imgs)
        ]

        # apply bbox post-processing to each image individually
        det_bboxes = []
        det_labels = []
        for i in range(num_imgs):
            det_bbox, det_label = self.bbox_head[-1].get_bboxes(
                rois[i],
                cls_score[i],
                bbox_pred[i],
                img_shapes[i],
                scale_factors[i],
                rescale=rescale,
                cfg=rcnn_test_cfg)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)
        bbox_result = [
            bbox2result(det_bboxes[i], det_labels[i],
                        self.bbox_head[-1].num_classes)
            for i in range(num_imgs)
        ]
        ms_bbox_result['ensemble'] = bbox_result

        if self.with_mask:
            if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
                # No detections in any image: empty mask result per class.
                mask_classes = self.mask_head[-1].num_classes
                segm_results = [[[] for _ in range(mask_classes)]
                                for _ in range(num_imgs)]
            else:
                # Map detected boxes back to the input-image scale before
                # pooling mask features, when results were rescaled.
                if rescale and not isinstance(scale_factors[0], float):
                    scale_factors = [
                        torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                        for scale_factor in scale_factors
                    ]
                _bboxes = [
                    det_bboxes[i][:, :4] *
                    scale_factors[i] if rescale else det_bboxes[i]
                    for i in range(num_imgs)
                ]
                mask_rois = bbox2roi(_bboxes)
                aug_masks = []
                mask_roi_extractor = self.mask_roi_extractor[-1]
                mask_feats = mask_roi_extractor(
                    x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
                if self.with_semantic and 'mask' in self.semantic_fusion:
                    mask_semantic_feat = self.semantic_roi_extractor(
                        [semantic_feat], mask_rois)
                    mask_feats += mask_semantic_feat
                last_feat = None

                num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes)
                for i in range(self.num_stages):
                    mask_head = self.mask_head[i]
                    if self.mask_info_flow:
                        mask_pred, last_feat = mask_head(mask_feats, last_feat)
                    else:
                        mask_pred = mask_head(mask_feats)

                    # split batch mask prediction back to each image
                    mask_pred = mask_pred.split(num_bbox_per_img, 0)
                    aug_masks.append(
                        [mask.sigmoid().cpu().numpy() for mask in mask_pred])

                # apply mask post-processing to each image individually
                segm_results = []
                for i in range(num_imgs):
                    if det_bboxes[i].shape[0] == 0:
                        segm_results.append(
                            [[]
                             for _ in range(self.mask_head[-1].num_classes)])
                    else:
                        # Merge the per-stage predictions of this image.
                        aug_mask = [mask[i] for mask in aug_masks]
                        merged_mask = merge_aug_masks(
                            aug_mask, [[img_metas[i]]] * self.num_stages,
                            rcnn_test_cfg)
                        segm_result = self.mask_head[-1].get_seg_masks(
                            merged_mask, _bboxes[i], det_labels[i],
                            rcnn_test_cfg, ori_shapes[i], scale_factors[i],
                            rescale)
                        segm_results.append(segm_result)
            ms_segm_result['ensemble'] = segm_results

        if self.with_mask:
            results = list(
                zip(ms_bbox_result['ensemble'], ms_segm_result['ensemble']))
        else:
            results = ms_bbox_result['ensemble']

        return results
  448. def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
  449. """Test with augmentations.
  450. If rescale is False, then returned bboxes and masks will fit the scale
  451. of imgs[0].
  452. """
  453. if self.with_semantic:
  454. semantic_feats = [
  455. self.semantic_head(feat)[1] for feat in img_feats
  456. ]
  457. else:
  458. semantic_feats = [None] * len(img_metas)
  459. rcnn_test_cfg = self.test_cfg
  460. aug_bboxes = []
  461. aug_scores = []
  462. for x, img_meta, semantic in zip(img_feats, img_metas, semantic_feats):
  463. # only one image in the batch
  464. img_shape = img_meta[0]['img_shape']
  465. scale_factor = img_meta[0]['scale_factor']
  466. flip = img_meta[0]['flip']
  467. flip_direction = img_meta[0]['flip_direction']
  468. proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
  469. scale_factor, flip, flip_direction)
  470. # "ms" in variable names means multi-stage
  471. ms_scores = []
  472. rois = bbox2roi([proposals])
  473. if rois.shape[0] == 0:
  474. # There is no proposal in the single image
  475. aug_bboxes.append(rois.new_zeros(0, 4))
  476. aug_scores.append(rois.new_zeros(0, 1))
  477. continue
  478. for i in range(self.num_stages):
  479. bbox_head = self.bbox_head[i]
  480. bbox_results = self._bbox_forward(
  481. i, x, rois, semantic_feat=semantic)
  482. ms_scores.append(bbox_results['cls_score'])
  483. if i < self.num_stages - 1:
  484. bbox_label = bbox_results['cls_score'].argmax(dim=1)
  485. rois = bbox_head.regress_by_class(
  486. rois, bbox_label, bbox_results['bbox_pred'],
  487. img_meta[0])
  488. cls_score = sum(ms_scores) / float(len(ms_scores))
  489. bboxes, scores = self.bbox_head[-1].get_bboxes(
  490. rois,
  491. cls_score,
  492. bbox_results['bbox_pred'],
  493. img_shape,
  494. scale_factor,
  495. rescale=False,
  496. cfg=None)
  497. aug_bboxes.append(bboxes)
  498. aug_scores.append(scores)
  499. # after merging, bboxes will be rescaled to the original image size
  500. merged_bboxes, merged_scores = merge_aug_bboxes(
  501. aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
  502. det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
  503. rcnn_test_cfg.score_thr,
  504. rcnn_test_cfg.nms,
  505. rcnn_test_cfg.max_per_img)
  506. bbox_result = bbox2result(det_bboxes, det_labels,
  507. self.bbox_head[-1].num_classes)
  508. if self.with_mask:
  509. if det_bboxes.shape[0] == 0:
  510. segm_result = [[]
  511. for _ in range(self.mask_head[-1].num_classes)]
  512. else:
  513. aug_masks = []
  514. aug_img_metas = []
  515. for x, img_meta, semantic in zip(img_feats, img_metas,
  516. semantic_feats):
  517. img_shape = img_meta[0]['img_shape']
  518. scale_factor = img_meta[0]['scale_factor']
  519. flip = img_meta[0]['flip']
  520. flip_direction = img_meta[0]['flip_direction']
  521. _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
  522. scale_factor, flip, flip_direction)
  523. mask_rois = bbox2roi([_bboxes])
  524. mask_feats = self.mask_roi_extractor[-1](
  525. x[:len(self.mask_roi_extractor[-1].featmap_strides)],
  526. mask_rois)
  527. if self.with_semantic:
  528. semantic_feat = semantic
  529. mask_semantic_feat = self.semantic_roi_extractor(
  530. [semantic_feat], mask_rois)
  531. if mask_semantic_feat.shape[-2:] != mask_feats.shape[
  532. -2:]:
  533. mask_semantic_feat = F.adaptive_avg_pool2d(
  534. mask_semantic_feat, mask_feats.shape[-2:])
  535. mask_feats += mask_semantic_feat
  536. last_feat = None
  537. for i in range(self.num_stages):
  538. mask_head = self.mask_head[i]
  539. if self.mask_info_flow:
  540. mask_pred, last_feat = mask_head(
  541. mask_feats, last_feat)
  542. else:
  543. mask_pred = mask_head(mask_feats)
  544. aug_masks.append(mask_pred.sigmoid().cpu().numpy())
  545. aug_img_metas.append(img_meta)
  546. merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
  547. self.test_cfg)
  548. ori_shape = img_metas[0][0]['ori_shape']
  549. segm_result = self.mask_head[-1].get_seg_masks(
  550. merged_masks,
  551. det_bboxes,
  552. det_labels,
  553. rcnn_test_cfg,
  554. ori_shape,
  555. scale_factor=1.0,
  556. rescale=False)
  557. return [(bbox_result, segm_result)]
  558. else:
  559. return [bbox_result]

No Description

Contributors (3)