cascade_roi_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import ModuleList

from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
                        build_sampler, merge_aug_bboxes, merge_aug_masks,
                        multiclass_nms)
from ..builder import HEADS, build_head, build_roi_extractor
from .base_roi_head import BaseRoIHead
from .test_mixins import BBoxTestMixin, MaskTestMixin


@HEADS.register_module()
class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
    """Cascade roi head including one bbox head and one mask head.

    https://arxiv.org/abs/1712.00726
    """

    def __init__(self,
                 num_stages,
                 stage_loss_weights,
                 bbox_roi_extractor=None,
                 bbox_head=None,
                 mask_roi_extractor=None,
                 mask_head=None,
                 shared_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):
        assert bbox_roi_extractor is not None
        assert bbox_head is not None
        assert shared_head is None, \
            'Shared head is not supported in Cascade RCNN anymore'

        self.num_stages = num_stages
        self.stage_loss_weights = stage_loss_weights
        super(CascadeRoIHead, self).__init__(
            bbox_roi_extractor=bbox_roi_extractor,
            bbox_head=bbox_head,
            mask_roi_extractor=mask_roi_extractor,
            mask_head=mask_head,
            shared_head=shared_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            pretrained=pretrained,
            init_cfg=init_cfg)
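
    # A hedged sketch of how this head is usually configured (the values and
    # the ``Shared2FCBBoxHead`` choice are illustrative assumptions, not
    # requirements): ``bbox_head``, ``bbox_roi_extractor``, ``mask_head`` and
    # ``mask_roi_extractor`` may each be a single config that gets replicated
    # across stages, or a list with one entry per stage, while ``train_cfg``
    # and ``stage_loss_weights`` are per-stage sequences, e.g.
    #
    #   roi_head = dict(
    #       type='CascadeRoIHead',
    #       num_stages=3,
    #       stage_loss_weights=[1, 0.5, 0.25],
    #       bbox_head=[dict(type='Shared2FCBBoxHead', ...) for _ in range(3)])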

    def init_bbox_head(self, bbox_roi_extractor, bbox_head):
        """Initialize box head and box roi extractor.

        Args:
            bbox_roi_extractor (dict): Config of box roi extractor.
            bbox_head (dict): Config of box head.
        """
        self.bbox_roi_extractor = ModuleList()
        self.bbox_head = ModuleList()
        if not isinstance(bbox_roi_extractor, list):
            bbox_roi_extractor = [
                bbox_roi_extractor for _ in range(self.num_stages)
            ]
        if not isinstance(bbox_head, list):
            bbox_head = [bbox_head for _ in range(self.num_stages)]
        assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages
        for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
            self.bbox_roi_extractor.append(build_roi_extractor(roi_extractor))
            self.bbox_head.append(build_head(head))

    def init_mask_head(self, mask_roi_extractor, mask_head):
        """Initialize mask head and mask roi extractor.

        Args:
            mask_roi_extractor (dict): Config of mask roi extractor.
            mask_head (dict): Config of mask head.
        """
        self.mask_head = nn.ModuleList()
        if not isinstance(mask_head, list):
            mask_head = [mask_head for _ in range(self.num_stages)]
        assert len(mask_head) == self.num_stages
        for head in mask_head:
            self.mask_head.append(build_head(head))
        if mask_roi_extractor is not None:
            self.share_roi_extractor = False
            self.mask_roi_extractor = ModuleList()
            if not isinstance(mask_roi_extractor, list):
                mask_roi_extractor = [
                    mask_roi_extractor for _ in range(self.num_stages)
                ]
            assert len(mask_roi_extractor) == self.num_stages
            for roi_extractor in mask_roi_extractor:
                self.mask_roi_extractor.append(
                    build_roi_extractor(roi_extractor))
        else:
            self.share_roi_extractor = True
            self.mask_roi_extractor = self.bbox_roi_extractor

    def init_assigner_sampler(self):
        """Initialize assigner and sampler for each stage."""
        self.bbox_assigner = []
        self.bbox_sampler = []
        if self.train_cfg is not None:
            for idx, rcnn_train_cfg in enumerate(self.train_cfg):
                self.bbox_assigner.append(
                    build_assigner(rcnn_train_cfg.assigner))
                self.current_stage = idx
                self.bbox_sampler.append(
                    build_sampler(rcnn_train_cfg.sampler, context=self))
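
    # In the standard Cascade R-CNN recipe (an assumption about the usual
    # config rather than something enforced here), each per-stage
    # ``rcnn_train_cfg`` uses a progressively higher positive IoU threshold
    # (e.g. 0.5, 0.6, 0.7), so later stages are trained on better proposals.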

    def forward_dummy(self, x, proposals):
        """Dummy forward function."""
        # bbox head
        outs = ()
        rois = bbox2roi([proposals])
        if self.with_bbox:
            for i in range(self.num_stages):
                bbox_results = self._bbox_forward(i, x, rois)
                outs = outs + (bbox_results['cls_score'],
                               bbox_results['bbox_pred'])
        # mask heads
        if self.with_mask:
            mask_rois = rois[:100]
            for i in range(self.num_stages):
                mask_results = self._mask_forward(i, x, mask_rois)
                outs = outs + (mask_results['mask_pred'], )
        return outs

    def _bbox_forward(self, stage, x, rois):
        """Box head forward function used in both training and testing."""
        bbox_roi_extractor = self.bbox_roi_extractor[stage]
        bbox_head = self.bbox_head[stage]
        bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
                                        rois)
        # do not support caffe_c4 model anymore
        cls_score, bbox_pred = bbox_head(bbox_feats)

        bbox_results = dict(
            cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
        return bbox_results

    def _bbox_forward_train(self, stage, x, sampling_results, gt_bboxes,
                            gt_labels, rcnn_train_cfg):
        """Run forward function and calculate loss for box head in
        training."""
        rois = bbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(stage, x, rois)
        bbox_targets = self.bbox_head[stage].get_targets(
            sampling_results, gt_bboxes, gt_labels, rcnn_train_cfg)
        loss_bbox = self.bbox_head[stage].loss(bbox_results['cls_score'],
                                               bbox_results['bbox_pred'], rois,
                                               *bbox_targets)

        bbox_results.update(
            loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
        return bbox_results

    def _mask_forward(self, stage, x, rois):
        """Mask head forward function used in both training and testing."""
        mask_roi_extractor = self.mask_roi_extractor[stage]
        mask_head = self.mask_head[stage]
        mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
                                        rois)
        # do not support caffe_c4 model anymore
        mask_pred = mask_head(mask_feats)

        mask_results = dict(mask_pred=mask_pred)
        return mask_results

    def _mask_forward_train(self,
                            stage,
                            x,
                            sampling_results,
                            gt_masks,
                            rcnn_train_cfg,
                            bbox_feats=None):
        """Run forward function and calculate loss for mask head in
        training."""
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
        mask_results = self._mask_forward(stage, x, pos_rois)

        mask_targets = self.mask_head[stage].get_targets(
            sampling_results, gt_masks, rcnn_train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head[stage].loss(mask_results['mask_pred'],
                                               mask_targets, pos_labels)

        mask_results.update(loss_mask=loss_mask)
        return mask_results

    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None):
        """
        Args:
            x (list[Tensor]): list of multi-level img features.
            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.
            proposal_list (list[Tensor]): list of region proposals.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.
            gt_masks (None | Tensor): true segmentation masks for each box,
                used if the architecture supports a segmentation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components.
        """
        losses = dict()
        for i in range(self.num_stages):
            self.current_stage = i
            rcnn_train_cfg = self.train_cfg[i]
            lw = self.stage_loss_weights[i]

            # assign gts and sample proposals
            sampling_results = []
            if self.with_bbox or self.with_mask:
                bbox_assigner = self.bbox_assigner[i]
                bbox_sampler = self.bbox_sampler[i]
                num_imgs = len(img_metas)
                if gt_bboxes_ignore is None:
                    gt_bboxes_ignore = [None for _ in range(num_imgs)]

                for j in range(num_imgs):
                    assign_result = bbox_assigner.assign(
                        proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j],
                        gt_labels[j])
                    sampling_result = bbox_sampler.sample(
                        assign_result,
                        proposal_list[j],
                        gt_bboxes[j],
                        gt_labels[j],
                        feats=[lvl_feat[j][None] for lvl_feat in x])
                    sampling_results.append(sampling_result)

            # bbox head forward and loss
            bbox_results = self._bbox_forward_train(i, x, sampling_results,
                                                    gt_bboxes, gt_labels,
                                                    rcnn_train_cfg)

            for name, value in bbox_results['loss_bbox'].items():
                losses[f's{i}.{name}'] = (
                    value * lw if 'loss' in name else value)

            # mask head forward and loss
            if self.with_mask:
                mask_results = self._mask_forward_train(
                    i, x, sampling_results, gt_masks, rcnn_train_cfg,
                    bbox_results['bbox_feats'])
                for name, value in mask_results['loss_mask'].items():
                    losses[f's{i}.{name}'] = (
                        value * lw if 'loss' in name else value)
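
            # Cascade refinement: the proposals fed to stage i + 1 are the
            # boxes regressed by stage i, so each stage sees progressively
            # better-localized inputs; ``refine_bboxes`` also drops the
            # ground-truth boxes that the sampler appended (``pos_is_gts``).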
            # refine bboxes
            if i < self.num_stages - 1:
                pos_is_gts = [res.pos_is_gt for res in sampling_results]
                # bbox_targets is a tuple
                roi_labels = bbox_results['bbox_targets'][0]
                with torch.no_grad():
                    cls_score = bbox_results['cls_score']
                    if self.bbox_head[i].custom_activation:
                        cls_score = self.bbox_head[i].loss_cls.get_activation(
                            cls_score)

                    # Empty proposal.
                    if cls_score.numel() == 0:
                        break

                    roi_labels = torch.where(
                        roi_labels == self.bbox_head[i].num_classes,
                        cls_score[:, :-1].argmax(1), roi_labels)
                    proposal_list = self.bbox_head[i].refine_bboxes(
                        bbox_results['rois'], roi_labels,
                        bbox_results['bbox_pred'], pos_is_gts, img_metas)

        return losses

    def simple_test(self, x, proposal_list, img_metas, rescale=False):
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Features from upstream network. Each
                has shape (batch_size, c, h, w).
            proposal_list (list[Tensor]): Proposals from rpn head.
                Each has shape (num_proposals, 5); the last dimension
                represents (x1, y1, x2, y2, score).
            img_metas (list[dict]): Meta information of images.
            rescale (bool): Whether to rescale the results to
                the original image. Default: False.

        Returns:
            list[list[np.ndarray]] or list[tuple]: When there is no mask
            branch, it is the bbox results of each image and classes with
            type `list[list[np.ndarray]]`. The outer list corresponds to each
            image; the inner list corresponds to each class. When the model
            has a mask branch, it contains bbox results and mask results:
            the outer list corresponds to each image, and the first element
            of each tuple is the bbox results, the second the mask results.
        """
        assert self.with_bbox, 'Bbox head must be implemented.'
        num_imgs = len(proposal_list)
        img_shapes = tuple(meta['img_shape'] for meta in img_metas)
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # "ms" in variable names means multi-stage
        ms_bbox_result = {}
        ms_segm_result = {}
        ms_scores = []
        rcnn_test_cfg = self.test_cfg

        rois = bbox2roi(proposal_list)

        if rois.shape[0] == 0:
            # There is no proposal in the whole batch
            bbox_results = [[
                np.zeros((0, 5), dtype=np.float32)
                for _ in range(self.bbox_head[-1].num_classes)
            ]] * num_imgs

            if self.with_mask:
                mask_classes = self.mask_head[-1].num_classes
                segm_results = [[[] for _ in range(mask_classes)]
                                for _ in range(num_imgs)]
                results = list(zip(bbox_results, segm_results))
            else:
                results = bbox_results

            return results

        for i in range(self.num_stages):
            bbox_results = self._bbox_forward(i, x, rois)

            # split batch bbox prediction back to each image
            cls_score = bbox_results['cls_score']
            bbox_pred = bbox_results['bbox_pred']
            num_proposals_per_img = tuple(
                len(proposals) for proposals in proposal_list)
            rois = rois.split(num_proposals_per_img, 0)
            cls_score = cls_score.split(num_proposals_per_img, 0)
            if isinstance(bbox_pred, torch.Tensor):
                bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
            else:
                bbox_pred = self.bbox_head[i].bbox_pred_split(
                    bbox_pred, num_proposals_per_img)
            ms_scores.append(cls_score)

            if i < self.num_stages - 1:
                if self.bbox_head[i].custom_activation:
                    cls_score = [
                        self.bbox_head[i].loss_cls.get_activation(s)
                        for s in cls_score
                    ]
                refine_rois_list = []
                for j in range(num_imgs):
                    if rois[j].shape[0] > 0:
                        bbox_label = cls_score[j][:, :-1].argmax(dim=1)
                        refined_rois = self.bbox_head[i].regress_by_class(
                            rois[j], bbox_label, bbox_pred[j], img_metas[j])
                        refine_rois_list.append(refined_rois)
                rois = torch.cat(refine_rois_list)
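
        # The final classification score for each image is the average of the
        # per-stage scores collected in ``ms_scores``, while the box deltas
        # used for decoding come from the last stage only.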
        # average scores of each image by stages
        cls_score = [
            sum([score[i] for score in ms_scores]) / float(len(ms_scores))
            for i in range(num_imgs)
        ]

        # apply bbox post-processing to each image individually
        det_bboxes = []
        det_labels = []
        for i in range(num_imgs):
            det_bbox, det_label = self.bbox_head[-1].get_bboxes(
                rois[i],
                cls_score[i],
                bbox_pred[i],
                img_shapes[i],
                scale_factors[i],
                rescale=rescale,
                cfg=rcnn_test_cfg)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)
        bbox_results = [
            bbox2result(det_bboxes[i], det_labels[i],
                        self.bbox_head[-1].num_classes)
            for i in range(num_imgs)
        ]
        ms_bbox_result['ensemble'] = bbox_results

        if self.with_mask:
            if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
                mask_classes = self.mask_head[-1].num_classes
                segm_results = [[[] for _ in range(mask_classes)]
                                for _ in range(num_imgs)]
            else:
                if rescale and not isinstance(scale_factors[0], float):
                    scale_factors = [
                        torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                        for scale_factor in scale_factors
                    ]
                _bboxes = [
                    det_bboxes[i][:, :4] *
                    scale_factors[i] if rescale else det_bboxes[i][:, :4]
                    for i in range(len(det_bboxes))
                ]
                mask_rois = bbox2roi(_bboxes)
                num_mask_rois_per_img = tuple(
                    _bbox.size(0) for _bbox in _bboxes)
                aug_masks = []
                for i in range(self.num_stages):
                    mask_results = self._mask_forward(i, x, mask_rois)
                    mask_pred = mask_results['mask_pred']
                    # split batch mask prediction back to each image
                    mask_pred = mask_pred.split(num_mask_rois_per_img, 0)
                    aug_masks.append([
                        m.sigmoid().cpu().detach().numpy() for m in mask_pred
                    ])

                # apply mask post-processing to each image individually
                segm_results = []
                for i in range(num_imgs):
                    if det_bboxes[i].shape[0] == 0:
                        segm_results.append(
                            [[]
                             for _ in range(self.mask_head[-1].num_classes)])
                    else:
                        aug_mask = [mask[i] for mask in aug_masks]
                        merged_masks = merge_aug_masks(
                            aug_mask, [[img_metas[i]]] * self.num_stages,
                            rcnn_test_cfg)
                        segm_result = self.mask_head[-1].get_seg_masks(
                            merged_masks, _bboxes[i], det_labels[i],
                            rcnn_test_cfg, ori_shapes[i], scale_factors[i],
                            rescale)
                        segm_results.append(segm_result)
            ms_segm_result['ensemble'] = segm_results

        if self.with_mask:
            results = list(
                zip(ms_bbox_result['ensemble'], ms_segm_result['ensemble']))
        else:
            results = ms_bbox_result['ensemble']

        return results

    def aug_test(self, features, proposal_list, img_metas, rescale=False):
        """Test with augmentations.

        If rescale is False, then returned bboxes and masks will fit the scale
        of imgs[0].
        """
        rcnn_test_cfg = self.test_cfg
        aug_bboxes = []
        aug_scores = []
        for x, img_meta in zip(features, img_metas):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']
            flip_direction = img_meta[0]['flip_direction']

            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip, flip_direction)
            # "ms" in variable names means multi-stage
            ms_scores = []

            rois = bbox2roi([proposals])

            if rois.shape[0] == 0:
                # There is no proposal in the single image
                aug_bboxes.append(rois.new_zeros(0, 4))
                aug_scores.append(rois.new_zeros(0, 1))
                continue

            for i in range(self.num_stages):
                bbox_results = self._bbox_forward(i, x, rois)
                ms_scores.append(bbox_results['cls_score'])

                if i < self.num_stages - 1:
                    cls_score = bbox_results['cls_score']
                    if self.bbox_head[i].custom_activation:
                        cls_score = self.bbox_head[i].loss_cls.get_activation(
                            cls_score)
                    bbox_label = cls_score[:, :-1].argmax(dim=1)
                    rois = self.bbox_head[i].regress_by_class(
                        rois, bbox_label, bbox_results['bbox_pred'],
                        img_meta[0])

            cls_score = sum(ms_scores) / float(len(ms_scores))
            bboxes, scores = self.bbox_head[-1].get_bboxes(
                rois,
                cls_score,
                bbox_results['bbox_pred'],
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)
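
        # Each augmented view is tested at its own scale/flip (``cfg=None``
        # above returns raw boxes and per-class scores without NMS); the
        # per-view results are then merged back into the original image frame
        # and a single NMS is applied below.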
        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                rcnn_test_cfg.score_thr,
                                                rcnn_test_cfg.nms,
                                                rcnn_test_cfg.max_per_img)

        bbox_result = bbox2result(det_bboxes, det_labels,
                                  self.bbox_head[-1].num_classes)

        if self.with_mask:
            if det_bboxes.shape[0] == 0:
                segm_result = [[]
                               for _ in range(self.mask_head[-1].num_classes)]
            else:
                aug_masks = []
                aug_img_metas = []
                for x, img_meta in zip(features, img_metas):
                    img_shape = img_meta[0]['img_shape']
                    scale_factor = img_meta[0]['scale_factor']
                    flip = img_meta[0]['flip']
                    flip_direction = img_meta[0]['flip_direction']
                    _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                           scale_factor, flip, flip_direction)
                    mask_rois = bbox2roi([_bboxes])
                    for i in range(self.num_stages):
                        mask_results = self._mask_forward(i, x, mask_rois)
                        aug_masks.append(
                            mask_results['mask_pred'].sigmoid().cpu().numpy())
                        aug_img_metas.append(img_meta)
                merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
                                               self.test_cfg)

                ori_shape = img_metas[0][0]['ori_shape']
                dummy_scale_factor = np.ones(4)
                segm_result = self.mask_head[-1].get_seg_masks(
                    merged_masks,
                    det_bboxes,
                    det_labels,
                    rcnn_test_cfg,
                    ori_shape,
                    scale_factor=dummy_scale_factor,
                    rescale=False)
            return [(bbox_result, segm_result)]
        else:
            return [bbox_result]

    def onnx_export(self, x, proposals, img_metas):
        assert self.with_bbox, 'Bbox head must be implemented.'
        assert proposals.shape[0] == 1, 'Only support one input image ' \
                                        'while in exporting to ONNX'
        # remove the scores
        rois = proposals[..., :-1]
        batch_size = rois.shape[0]
        num_proposals_per_img = rois.shape[1]

        # Eliminate the batch dimension
        rois = rois.view(-1, 4)
        # add dummy batch index
        rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois], dim=-1)
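
        # ``rois`` now follow the (num_rois, 5) layout used by mmdet ROI ops:
        # [batch_index, x1, y1, x2, y2]; with the single exported image the
        # batch index is always 0.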

        max_shape = img_metas[0]['img_shape_for_onnx']
        ms_scores = []
        rcnn_test_cfg = self.test_cfg

        for i in range(self.num_stages):
            bbox_results = self._bbox_forward(i, x, rois)

            cls_score = bbox_results['cls_score']
            bbox_pred = bbox_results['bbox_pred']
            # Recover the batch dimension
            rois = rois.reshape(batch_size, num_proposals_per_img,
                                rois.size(-1))
            cls_score = cls_score.reshape(batch_size, num_proposals_per_img,
                                          cls_score.size(-1))
            bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
            ms_scores.append(cls_score)
            if i < self.num_stages - 1:
                assert self.bbox_head[i].reg_class_agnostic
                new_rois = self.bbox_head[i].bbox_coder.decode(
                    rois[..., 1:], bbox_pred, max_shape=max_shape)
                rois = new_rois.reshape(-1, new_rois.shape[-1])
                # add dummy batch index
                rois = torch.cat([rois.new_zeros(rois.shape[0], 1), rois],
                                 dim=-1)

        cls_score = sum(ms_scores) / float(len(ms_scores))
        bbox_pred = bbox_pred.reshape(batch_size, num_proposals_per_img, 4)
        rois = rois.reshape(batch_size, num_proposals_per_img, -1)
        det_bboxes, det_labels = self.bbox_head[-1].onnx_export(
            rois, cls_score, bbox_pred, max_shape, cfg=rcnn_test_cfg)

        if not self.with_mask:
            return det_bboxes, det_labels
        else:
            batch_index = torch.arange(
                det_bboxes.size(0),
                device=det_bboxes.device).float().view(-1, 1, 1).expand(
                    det_bboxes.size(0), det_bboxes.size(1), 1)
            rois = det_bboxes[..., :4]
            mask_rois = torch.cat([batch_index, rois], dim=-1)
            mask_rois = mask_rois.view(-1, 5)
            aug_masks = []
            for i in range(self.num_stages):
                mask_results = self._mask_forward(i, x, mask_rois)
                mask_pred = mask_results['mask_pred']
                aug_masks.append(mask_pred)
            max_shape = img_metas[0]['img_shape_for_onnx']
            # calculate the mean of masks from several stages
            mask_pred = sum(aug_masks) / len(aug_masks)
            segm_results = self.mask_head[-1].onnx_export(
                mask_pred, rois.reshape(-1, 4), det_labels.reshape(-1),
                self.test_cfg, max_shape)
            segm_results = segm_results.reshape(batch_size,
                                                det_bboxes.shape[1],
                                                max_shape[0], max_shape[1])
            return det_bboxes, det_labels, segm_results
