You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

scnet_roi_head.py 26 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import torch
  4. import torch.nn.functional as F
  5. from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
  6. merge_aug_masks, multiclass_nms)
  7. from ..builder import HEADS, build_head, build_roi_extractor
  8. from ..utils.brick_wrappers import adaptive_avg_pool2d
  9. from .cascade_roi_head import CascadeRoIHead
  10. @HEADS.register_module()
  11. class SCNetRoIHead(CascadeRoIHead):
  12. """RoIHead for `SCNet <https://arxiv.org/abs/2012.10150>`_.
  13. Args:
  14. num_stages (int): number of cascade stages.
  15. stage_loss_weights (list): loss weight of cascade stages.
  16. semantic_roi_extractor (dict): config to init semantic roi extractor.
  17. semantic_head (dict): config to init semantic head.
  18. feat_relay_head (dict): config to init feature_relay_head.
  19. glbctx_head (dict): config to init global context head.
  20. """
  21. def __init__(self,
  22. num_stages,
  23. stage_loss_weights,
  24. semantic_roi_extractor=None,
  25. semantic_head=None,
  26. feat_relay_head=None,
  27. glbctx_head=None,
  28. **kwargs):
  29. super(SCNetRoIHead, self).__init__(num_stages, stage_loss_weights,
  30. **kwargs)
  31. assert self.with_bbox and self.with_mask
  32. assert not self.with_shared_head # shared head is not supported
  33. if semantic_head is not None:
  34. self.semantic_roi_extractor = build_roi_extractor(
  35. semantic_roi_extractor)
  36. self.semantic_head = build_head(semantic_head)
  37. if feat_relay_head is not None:
  38. self.feat_relay_head = build_head(feat_relay_head)
  39. if glbctx_head is not None:
  40. self.glbctx_head = build_head(glbctx_head)
  41. def init_mask_head(self, mask_roi_extractor, mask_head):
  42. """Initialize ``mask_head``"""
  43. if mask_roi_extractor is not None:
  44. self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
  45. self.mask_head = build_head(mask_head)
  46. @property
  47. def with_semantic(self):
  48. """bool: whether the head has semantic head"""
  49. return hasattr(self,
  50. 'semantic_head') and self.semantic_head is not None
  51. @property
  52. def with_feat_relay(self):
  53. """bool: whether the head has feature relay head"""
  54. return (hasattr(self, 'feat_relay_head')
  55. and self.feat_relay_head is not None)
  56. @property
  57. def with_glbctx(self):
  58. """bool: whether the head has global context head"""
  59. return hasattr(self, 'glbctx_head') and self.glbctx_head is not None
  60. def _fuse_glbctx(self, roi_feats, glbctx_feat, rois):
  61. """Fuse global context feats with roi feats."""
  62. assert roi_feats.size(0) == rois.size(0)
  63. img_inds = torch.unique(rois[:, 0].cpu(), sorted=True).long()
  64. fused_feats = torch.zeros_like(roi_feats)
  65. for img_id in img_inds:
  66. inds = (rois[:, 0] == img_id.item())
  67. fused_feats[inds] = roi_feats[inds] + glbctx_feat[img_id]
  68. return fused_feats
  69. def _slice_pos_feats(self, feats, sampling_results):
  70. """Get features from pos rois."""
  71. num_rois = [res.bboxes.size(0) for res in sampling_results]
  72. num_pos_rois = [res.pos_bboxes.size(0) for res in sampling_results]
  73. inds = torch.zeros(sum(num_rois), dtype=torch.bool)
  74. start = 0
  75. for i in range(len(num_rois)):
  76. start = 0 if i == 0 else start + num_rois[i - 1]
  77. stop = start + num_pos_rois[i]
  78. inds[start:stop] = 1
  79. sliced_feats = feats[inds]
  80. return sliced_feats
  81. def _bbox_forward(self,
  82. stage,
  83. x,
  84. rois,
  85. semantic_feat=None,
  86. glbctx_feat=None):
  87. """Box head forward function used in both training and testing."""
  88. bbox_roi_extractor = self.bbox_roi_extractor[stage]
  89. bbox_head = self.bbox_head[stage]
  90. bbox_feats = bbox_roi_extractor(
  91. x[:len(bbox_roi_extractor.featmap_strides)], rois)
  92. if self.with_semantic and semantic_feat is not None:
  93. bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
  94. rois)
  95. if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
  96. bbox_semantic_feat = adaptive_avg_pool2d(
  97. bbox_semantic_feat, bbox_feats.shape[-2:])
  98. bbox_feats += bbox_semantic_feat
  99. if self.with_glbctx and glbctx_feat is not None:
  100. bbox_feats = self._fuse_glbctx(bbox_feats, glbctx_feat, rois)
  101. cls_score, bbox_pred, relayed_feat = bbox_head(
  102. bbox_feats, return_shared_feat=True)
  103. bbox_results = dict(
  104. cls_score=cls_score,
  105. bbox_pred=bbox_pred,
  106. relayed_feat=relayed_feat)
  107. return bbox_results
  108. def _mask_forward(self,
  109. x,
  110. rois,
  111. semantic_feat=None,
  112. glbctx_feat=None,
  113. relayed_feat=None):
  114. """Mask head forward function used in both training and testing."""
  115. mask_feats = self.mask_roi_extractor(
  116. x[:self.mask_roi_extractor.num_inputs], rois)
  117. if self.with_semantic and semantic_feat is not None:
  118. mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
  119. rois)
  120. if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
  121. mask_semantic_feat = F.adaptive_avg_pool2d(
  122. mask_semantic_feat, mask_feats.shape[-2:])
  123. mask_feats += mask_semantic_feat
  124. if self.with_glbctx and glbctx_feat is not None:
  125. mask_feats = self._fuse_glbctx(mask_feats, glbctx_feat, rois)
  126. if self.with_feat_relay and relayed_feat is not None:
  127. mask_feats = mask_feats + relayed_feat
  128. mask_pred = self.mask_head(mask_feats)
  129. mask_results = dict(mask_pred=mask_pred)
  130. return mask_results
  131. def _bbox_forward_train(self,
  132. stage,
  133. x,
  134. sampling_results,
  135. gt_bboxes,
  136. gt_labels,
  137. rcnn_train_cfg,
  138. semantic_feat=None,
  139. glbctx_feat=None):
  140. """Run forward function and calculate loss for box head in training."""
  141. bbox_head = self.bbox_head[stage]
  142. rois = bbox2roi([res.bboxes for res in sampling_results])
  143. bbox_results = self._bbox_forward(
  144. stage,
  145. x,
  146. rois,
  147. semantic_feat=semantic_feat,
  148. glbctx_feat=glbctx_feat)
  149. bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes,
  150. gt_labels, rcnn_train_cfg)
  151. loss_bbox = bbox_head.loss(bbox_results['cls_score'],
  152. bbox_results['bbox_pred'], rois,
  153. *bbox_targets)
  154. bbox_results.update(
  155. loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
  156. return bbox_results
  157. def _mask_forward_train(self,
  158. x,
  159. sampling_results,
  160. gt_masks,
  161. rcnn_train_cfg,
  162. semantic_feat=None,
  163. glbctx_feat=None,
  164. relayed_feat=None):
  165. """Run forward function and calculate loss for mask head in
  166. training."""
  167. pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
  168. mask_results = self._mask_forward(
  169. x,
  170. pos_rois,
  171. semantic_feat=semantic_feat,
  172. glbctx_feat=glbctx_feat,
  173. relayed_feat=relayed_feat)
  174. mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
  175. rcnn_train_cfg)
  176. pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
  177. loss_mask = self.mask_head.loss(mask_results['mask_pred'],
  178. mask_targets, pos_labels)
  179. mask_results = loss_mask
  180. return mask_results
  181. def forward_train(self,
  182. x,
  183. img_metas,
  184. proposal_list,
  185. gt_bboxes,
  186. gt_labels,
  187. gt_bboxes_ignore=None,
  188. gt_masks=None,
  189. gt_semantic_seg=None):
  190. """
  191. Args:
  192. x (list[Tensor]): list of multi-level img features.
  193. img_metas (list[dict]): list of image info dict where each dict
  194. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  195. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  196. For details on the values of these keys see
  197. `mmdet/datasets/pipelines/formatting.py:Collect`.
  198. proposal_list (list[Tensors]): list of region proposals.
  199. gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
  200. shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
  201. gt_labels (list[Tensor]): class indices corresponding to each box
  202. gt_bboxes_ignore (None, list[Tensor]): specify which bounding
  203. boxes can be ignored when computing the loss.
  204. gt_masks (None, Tensor) : true segmentation masks for each box
  205. used if the architecture supports a segmentation task.
  206. gt_semantic_seg (None, list[Tensor]): semantic segmentation masks
  207. used if the architecture supports semantic segmentation task.
  208. Returns:
  209. dict[str, Tensor]: a dictionary of loss components
  210. """
  211. losses = dict()
  212. # semantic segmentation branch
  213. if self.with_semantic:
  214. semantic_pred, semantic_feat = self.semantic_head(x)
  215. loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
  216. losses['loss_semantic_seg'] = loss_seg
  217. else:
  218. semantic_feat = None
  219. # global context branch
  220. if self.with_glbctx:
  221. mc_pred, glbctx_feat = self.glbctx_head(x)
  222. loss_glbctx = self.glbctx_head.loss(mc_pred, gt_labels)
  223. losses['loss_glbctx'] = loss_glbctx
  224. else:
  225. glbctx_feat = None
  226. for i in range(self.num_stages):
  227. self.current_stage = i
  228. rcnn_train_cfg = self.train_cfg[i]
  229. lw = self.stage_loss_weights[i]
  230. # assign gts and sample proposals
  231. sampling_results = []
  232. bbox_assigner = self.bbox_assigner[i]
  233. bbox_sampler = self.bbox_sampler[i]
  234. num_imgs = len(img_metas)
  235. if gt_bboxes_ignore is None:
  236. gt_bboxes_ignore = [None for _ in range(num_imgs)]
  237. for j in range(num_imgs):
  238. assign_result = bbox_assigner.assign(proposal_list[j],
  239. gt_bboxes[j],
  240. gt_bboxes_ignore[j],
  241. gt_labels[j])
  242. sampling_result = bbox_sampler.sample(
  243. assign_result,
  244. proposal_list[j],
  245. gt_bboxes[j],
  246. gt_labels[j],
  247. feats=[lvl_feat[j][None] for lvl_feat in x])
  248. sampling_results.append(sampling_result)
  249. bbox_results = \
  250. self._bbox_forward_train(
  251. i, x, sampling_results, gt_bboxes, gt_labels,
  252. rcnn_train_cfg, semantic_feat, glbctx_feat)
  253. roi_labels = bbox_results['bbox_targets'][0]
  254. for name, value in bbox_results['loss_bbox'].items():
  255. losses[f's{i}.{name}'] = (
  256. value * lw if 'loss' in name else value)
  257. # refine boxes
  258. if i < self.num_stages - 1:
  259. pos_is_gts = [res.pos_is_gt for res in sampling_results]
  260. with torch.no_grad():
  261. proposal_list = self.bbox_head[i].refine_bboxes(
  262. bbox_results['rois'], roi_labels,
  263. bbox_results['bbox_pred'], pos_is_gts, img_metas)
  264. if self.with_feat_relay:
  265. relayed_feat = self._slice_pos_feats(bbox_results['relayed_feat'],
  266. sampling_results)
  267. relayed_feat = self.feat_relay_head(relayed_feat)
  268. else:
  269. relayed_feat = None
  270. mask_results = self._mask_forward_train(x, sampling_results, gt_masks,
  271. rcnn_train_cfg, semantic_feat,
  272. glbctx_feat, relayed_feat)
  273. mask_lw = sum(self.stage_loss_weights)
  274. losses['loss_mask'] = mask_lw * mask_results['loss_mask']
  275. return losses
  276. def simple_test(self, x, proposal_list, img_metas, rescale=False):
  277. """Test without augmentation.
  278. Args:
  279. x (tuple[Tensor]): Features from upstream network. Each
  280. has shape (batch_size, c, h, w).
  281. proposal_list (list(Tensor)): Proposals from rpn head.
  282. Each has shape (num_proposals, 5), last dimension
  283. 5 represent (x1, y1, x2, y2, score).
  284. img_metas (list[dict]): Meta information of images.
  285. rescale (bool): Whether to rescale the results to
  286. the original image. Default: True.
  287. Returns:
  288. list[list[np.ndarray]] or list[tuple]: When no mask branch,
  289. it is bbox results of each image and classes with type
  290. `list[list[np.ndarray]]`. The outer list
  291. corresponds to each image. The inner list
  292. corresponds to each class. When the model has mask branch,
  293. it contains bbox results and mask results.
  294. The outer list corresponds to each image, and first element
  295. of tuple is bbox results, second element is mask results.
  296. """
  297. if self.with_semantic:
  298. _, semantic_feat = self.semantic_head(x)
  299. else:
  300. semantic_feat = None
  301. if self.with_glbctx:
  302. mc_pred, glbctx_feat = self.glbctx_head(x)
  303. else:
  304. glbctx_feat = None
  305. num_imgs = len(proposal_list)
  306. img_shapes = tuple(meta['img_shape'] for meta in img_metas)
  307. ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
  308. scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
  309. # "ms" in variable names means multi-stage
  310. ms_scores = []
  311. rcnn_test_cfg = self.test_cfg
  312. rois = bbox2roi(proposal_list)
  313. if rois.shape[0] == 0:
  314. # There is no proposal in the whole batch
  315. bbox_results = [[
  316. np.zeros((0, 5), dtype=np.float32)
  317. for _ in range(self.bbox_head[-1].num_classes)
  318. ]] * num_imgs
  319. if self.with_mask:
  320. mask_classes = self.mask_head.num_classes
  321. segm_results = [[[] for _ in range(mask_classes)]
  322. for _ in range(num_imgs)]
  323. results = list(zip(bbox_results, segm_results))
  324. else:
  325. results = bbox_results
  326. return results
  327. for i in range(self.num_stages):
  328. bbox_head = self.bbox_head[i]
  329. bbox_results = self._bbox_forward(
  330. i,
  331. x,
  332. rois,
  333. semantic_feat=semantic_feat,
  334. glbctx_feat=glbctx_feat)
  335. # split batch bbox prediction back to each image
  336. cls_score = bbox_results['cls_score']
  337. bbox_pred = bbox_results['bbox_pred']
  338. num_proposals_per_img = tuple(len(p) for p in proposal_list)
  339. rois = rois.split(num_proposals_per_img, 0)
  340. cls_score = cls_score.split(num_proposals_per_img, 0)
  341. bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
  342. ms_scores.append(cls_score)
  343. if i < self.num_stages - 1:
  344. refine_rois_list = []
  345. for j in range(num_imgs):
  346. if rois[j].shape[0] > 0:
  347. bbox_label = cls_score[j][:, :-1].argmax(dim=1)
  348. refine_rois = bbox_head.regress_by_class(
  349. rois[j], bbox_label, bbox_pred[j], img_metas[j])
  350. refine_rois_list.append(refine_rois)
  351. rois = torch.cat(refine_rois_list)
  352. # average scores of each image by stages
  353. cls_score = [
  354. sum([score[i] for score in ms_scores]) / float(len(ms_scores))
  355. for i in range(num_imgs)
  356. ]
  357. # apply bbox post-processing to each image individually
  358. det_bboxes = []
  359. det_labels = []
  360. for i in range(num_imgs):
  361. det_bbox, det_label = self.bbox_head[-1].get_bboxes(
  362. rois[i],
  363. cls_score[i],
  364. bbox_pred[i],
  365. img_shapes[i],
  366. scale_factors[i],
  367. rescale=rescale,
  368. cfg=rcnn_test_cfg)
  369. det_bboxes.append(det_bbox)
  370. det_labels.append(det_label)
  371. det_bbox_results = [
  372. bbox2result(det_bboxes[i], det_labels[i],
  373. self.bbox_head[-1].num_classes)
  374. for i in range(num_imgs)
  375. ]
  376. if self.with_mask:
  377. if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
  378. mask_classes = self.mask_head.num_classes
  379. det_segm_results = [[[] for _ in range(mask_classes)]
  380. for _ in range(num_imgs)]
  381. else:
  382. if rescale and not isinstance(scale_factors[0], float):
  383. scale_factors = [
  384. torch.from_numpy(scale_factor).to(det_bboxes[0].device)
  385. for scale_factor in scale_factors
  386. ]
  387. _bboxes = [
  388. det_bboxes[i][:, :4] *
  389. scale_factors[i] if rescale else det_bboxes[i]
  390. for i in range(num_imgs)
  391. ]
  392. mask_rois = bbox2roi(_bboxes)
  393. # get relay feature on mask_rois
  394. bbox_results = self._bbox_forward(
  395. -1,
  396. x,
  397. mask_rois,
  398. semantic_feat=semantic_feat,
  399. glbctx_feat=glbctx_feat)
  400. relayed_feat = bbox_results['relayed_feat']
  401. relayed_feat = self.feat_relay_head(relayed_feat)
  402. mask_results = self._mask_forward(
  403. x,
  404. mask_rois,
  405. semantic_feat=semantic_feat,
  406. glbctx_feat=glbctx_feat,
  407. relayed_feat=relayed_feat)
  408. mask_pred = mask_results['mask_pred']
  409. # split batch mask prediction back to each image
  410. num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes)
  411. mask_preds = mask_pred.split(num_bbox_per_img, 0)
  412. # apply mask post-processing to each image individually
  413. det_segm_results = []
  414. for i in range(num_imgs):
  415. if det_bboxes[i].shape[0] == 0:
  416. det_segm_results.append(
  417. [[] for _ in range(self.mask_head.num_classes)])
  418. else:
  419. segm_result = self.mask_head.get_seg_masks(
  420. mask_preds[i], _bboxes[i], det_labels[i],
  421. self.test_cfg, ori_shapes[i], scale_factors[i],
  422. rescale)
  423. det_segm_results.append(segm_result)
  424. # return results
  425. if self.with_mask:
  426. return list(zip(det_bbox_results, det_segm_results))
  427. else:
  428. return det_bbox_results
  429. def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
  430. if self.with_semantic:
  431. semantic_feats = [
  432. self.semantic_head(feat)[1] for feat in img_feats
  433. ]
  434. else:
  435. semantic_feats = [None] * len(img_metas)
  436. if self.with_glbctx:
  437. glbctx_feats = [self.glbctx_head(feat)[1] for feat in img_feats]
  438. else:
  439. glbctx_feats = [None] * len(img_metas)
  440. rcnn_test_cfg = self.test_cfg
  441. aug_bboxes = []
  442. aug_scores = []
  443. for x, img_meta, semantic_feat, glbctx_feat in zip(
  444. img_feats, img_metas, semantic_feats, glbctx_feats):
  445. # only one image in the batch
  446. img_shape = img_meta[0]['img_shape']
  447. scale_factor = img_meta[0]['scale_factor']
  448. flip = img_meta[0]['flip']
  449. proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
  450. scale_factor, flip)
  451. # "ms" in variable names means multi-stage
  452. ms_scores = []
  453. rois = bbox2roi([proposals])
  454. if rois.shape[0] == 0:
  455. # There is no proposal in the single image
  456. aug_bboxes.append(rois.new_zeros(0, 4))
  457. aug_scores.append(rois.new_zeros(0, 1))
  458. continue
  459. for i in range(self.num_stages):
  460. bbox_head = self.bbox_head[i]
  461. bbox_results = self._bbox_forward(
  462. i,
  463. x,
  464. rois,
  465. semantic_feat=semantic_feat,
  466. glbctx_feat=glbctx_feat)
  467. ms_scores.append(bbox_results['cls_score'])
  468. if i < self.num_stages - 1:
  469. bbox_label = bbox_results['cls_score'].argmax(dim=1)
  470. rois = bbox_head.regress_by_class(
  471. rois, bbox_label, bbox_results['bbox_pred'],
  472. img_meta[0])
  473. cls_score = sum(ms_scores) / float(len(ms_scores))
  474. bboxes, scores = self.bbox_head[-1].get_bboxes(
  475. rois,
  476. cls_score,
  477. bbox_results['bbox_pred'],
  478. img_shape,
  479. scale_factor,
  480. rescale=False,
  481. cfg=None)
  482. aug_bboxes.append(bboxes)
  483. aug_scores.append(scores)
  484. # after merging, bboxes will be rescaled to the original image size
  485. merged_bboxes, merged_scores = merge_aug_bboxes(
  486. aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
  487. det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
  488. rcnn_test_cfg.score_thr,
  489. rcnn_test_cfg.nms,
  490. rcnn_test_cfg.max_per_img)
  491. det_bbox_results = bbox2result(det_bboxes, det_labels,
  492. self.bbox_head[-1].num_classes)
  493. if self.with_mask:
  494. if det_bboxes.shape[0] == 0:
  495. det_segm_results = [[]
  496. for _ in range(self.mask_head.num_classes)]
  497. else:
  498. aug_masks = []
  499. for x, img_meta, semantic_feat, glbctx_feat in zip(
  500. img_feats, img_metas, semantic_feats, glbctx_feats):
  501. img_shape = img_meta[0]['img_shape']
  502. scale_factor = img_meta[0]['scale_factor']
  503. flip = img_meta[0]['flip']
  504. _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
  505. scale_factor, flip)
  506. mask_rois = bbox2roi([_bboxes])
  507. # get relay feature on mask_rois
  508. bbox_results = self._bbox_forward(
  509. -1,
  510. x,
  511. mask_rois,
  512. semantic_feat=semantic_feat,
  513. glbctx_feat=glbctx_feat)
  514. relayed_feat = bbox_results['relayed_feat']
  515. relayed_feat = self.feat_relay_head(relayed_feat)
  516. mask_results = self._mask_forward(
  517. x,
  518. mask_rois,
  519. semantic_feat=semantic_feat,
  520. glbctx_feat=glbctx_feat,
  521. relayed_feat=relayed_feat)
  522. mask_pred = mask_results['mask_pred']
  523. aug_masks.append(mask_pred.sigmoid().cpu().numpy())
  524. merged_masks = merge_aug_masks(aug_masks, img_metas,
  525. self.test_cfg)
  526. ori_shape = img_metas[0][0]['ori_shape']
  527. det_segm_results = self.mask_head.get_seg_masks(
  528. merged_masks,
  529. det_bboxes,
  530. det_labels,
  531. rcnn_test_cfg,
  532. ori_shape,
  533. scale_factor=1.0,
  534. rescale=False)
  535. return [(det_bbox_results, det_segm_results)]
  536. else:
  537. return [det_bbox_results]

No Description

Contributors (3)