You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

base_roi_head.py 3.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from abc import ABCMeta, abstractmethod
  3. from mmcv.runner import BaseModule
  4. from ..builder import build_shared_head
  5. class BaseRoIHead(BaseModule, metaclass=ABCMeta):
  6. """Base class for RoIHeads."""
  7. def __init__(self,
  8. bbox_roi_extractor=None,
  9. bbox_head=None,
  10. mask_roi_extractor=None,
  11. mask_head=None,
  12. shared_head=None,
  13. train_cfg=None,
  14. test_cfg=None,
  15. pretrained=None,
  16. init_cfg=None):
  17. super(BaseRoIHead, self).__init__(init_cfg)
  18. self.train_cfg = train_cfg
  19. self.test_cfg = test_cfg
  20. if shared_head is not None:
  21. shared_head.pretrained = pretrained
  22. self.shared_head = build_shared_head(shared_head)
  23. if bbox_head is not None:
  24. self.init_bbox_head(bbox_roi_extractor, bbox_head)
  25. if mask_head is not None:
  26. self.init_mask_head(mask_roi_extractor, mask_head)
  27. self.init_assigner_sampler()
  28. @property
  29. def with_bbox(self):
  30. """bool: whether the RoI head contains a `bbox_head`"""
  31. return hasattr(self, 'bbox_head') and self.bbox_head is not None
  32. @property
  33. def with_mask(self):
  34. """bool: whether the RoI head contains a `mask_head`"""
  35. return hasattr(self, 'mask_head') and self.mask_head is not None
  36. @property
  37. def with_shared_head(self):
  38. """bool: whether the RoI head contains a `shared_head`"""
  39. return hasattr(self, 'shared_head') and self.shared_head is not None
  40. @abstractmethod
  41. def init_bbox_head(self):
  42. """Initialize ``bbox_head``"""
  43. pass
  44. @abstractmethod
  45. def init_mask_head(self):
  46. """Initialize ``mask_head``"""
  47. pass
  48. @abstractmethod
  49. def init_assigner_sampler(self):
  50. """Initialize assigner and sampler."""
  51. pass
  52. @abstractmethod
  53. def forward_train(self,
  54. x,
  55. img_meta,
  56. proposal_list,
  57. gt_bboxes,
  58. gt_labels,
  59. gt_bboxes_ignore=None,
  60. gt_masks=None,
  61. **kwargs):
  62. """Forward function during training."""
  63. async def async_simple_test(self,
  64. x,
  65. proposal_list,
  66. img_metas,
  67. proposals=None,
  68. rescale=False,
  69. **kwargs):
  70. """Asynchronized test function."""
  71. raise NotImplementedError
  72. def simple_test(self,
  73. x,
  74. proposal_list,
  75. img_meta,
  76. proposals=None,
  77. rescale=False,
  78. **kwargs):
  79. """Test without augmentation."""
  80. def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):
  81. """Test with augmentations.
  82. If rescale is False, then returned bboxes and masks will fit the scale
  83. of imgs[0].
  84. """

No Description

Contributors (2)