You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

trident_faster_rcnn.py 2.9 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from ..builder import DETECTORS
  3. from .faster_rcnn import FasterRCNN
  4. @DETECTORS.register_module()
  5. class TridentFasterRCNN(FasterRCNN):
  6. """Implementation of `TridentNet <https://arxiv.org/abs/1901.01892>`_"""
  7. def __init__(self,
  8. backbone,
  9. rpn_head,
  10. roi_head,
  11. train_cfg,
  12. test_cfg,
  13. neck=None,
  14. pretrained=None,
  15. init_cfg=None):
  16. super(TridentFasterRCNN, self).__init__(
  17. backbone=backbone,
  18. neck=neck,
  19. rpn_head=rpn_head,
  20. roi_head=roi_head,
  21. train_cfg=train_cfg,
  22. test_cfg=test_cfg,
  23. pretrained=pretrained,
  24. init_cfg=init_cfg)
  25. assert self.backbone.num_branch == self.roi_head.num_branch
  26. assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx
  27. self.num_branch = self.backbone.num_branch
  28. self.test_branch_idx = self.backbone.test_branch_idx
  29. def simple_test(self, img, img_metas, proposals=None, rescale=False):
  30. """Test without augmentation."""
  31. assert self.with_bbox, 'Bbox head must be implemented.'
  32. x = self.extract_feat(img)
  33. if proposals is None:
  34. num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
  35. trident_img_metas = img_metas * num_branch
  36. proposal_list = self.rpn_head.simple_test_rpn(x, trident_img_metas)
  37. else:
  38. proposal_list = proposals
  39. # TODO: Fix trident_img_metas undefined errors
  40. # when proposals is specified
  41. return self.roi_head.simple_test(
  42. x, proposal_list, trident_img_metas, rescale=rescale)
  43. def aug_test(self, imgs, img_metas, rescale=False):
  44. """Test with augmentations.
  45. If rescale is False, then returned bboxes and masks will fit the scale
  46. of imgs[0].
  47. """
  48. x = self.extract_feats(imgs)
  49. num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
  50. trident_img_metas = [img_metas * num_branch for img_metas in img_metas]
  51. proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
  52. return self.roi_head.aug_test(
  53. x, proposal_list, img_metas, rescale=rescale)
  54. def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
  55. """make copies of img and gts to fit multi-branch."""
  56. trident_gt_bboxes = tuple(gt_bboxes * self.num_branch)
  57. trident_gt_labels = tuple(gt_labels * self.num_branch)
  58. trident_img_metas = tuple(img_metas * self.num_branch)
  59. return super(TridentFasterRCNN,
  60. self).forward_train(img, trident_img_metas,
  61. trident_gt_bboxes, trident_gt_labels)

No Description

Contributors (1)