You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

double_roi_head.py 1.3 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from ..builder import HEADS
  3. from .standard_roi_head import StandardRoIHead
  4. @HEADS.register_module()
  5. class DoubleHeadRoIHead(StandardRoIHead):
  6. """RoI head for Double Head RCNN.
  7. https://arxiv.org/abs/1904.06493
  8. """
  9. def __init__(self, reg_roi_scale_factor, **kwargs):
  10. super(DoubleHeadRoIHead, self).__init__(**kwargs)
  11. self.reg_roi_scale_factor = reg_roi_scale_factor
  12. def _bbox_forward(self, x, rois):
  13. """Box head forward function used in both training and testing time."""
  14. bbox_cls_feats = self.bbox_roi_extractor(
  15. x[:self.bbox_roi_extractor.num_inputs], rois)
  16. bbox_reg_feats = self.bbox_roi_extractor(
  17. x[:self.bbox_roi_extractor.num_inputs],
  18. rois,
  19. roi_scale_factor=self.reg_roi_scale_factor)
  20. if self.with_shared_head:
  21. bbox_cls_feats = self.shared_head(bbox_cls_feats)
  22. bbox_reg_feats = self.shared_head(bbox_reg_feats)
  23. cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
  24. bbox_results = dict(
  25. cls_score=cls_score,
  26. bbox_pred=bbox_pred,
  27. bbox_feats=bbox_cls_feats)
  28. return bbox_results

No Description

Contributors (3)