You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pisa_roi_head.py 6.7 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from mmdet.core import bbox2roi
  3. from ..builder import HEADS
  4. from ..losses.pisa_loss import carl_loss, isr_p
  5. from .standard_roi_head import StandardRoIHead
  6. @HEADS.register_module()
  7. class PISARoIHead(StandardRoIHead):
  8. r"""The RoI head for `Prime Sample Attention in Object Detection
  9. <https://arxiv.org/abs/1904.04821>`_."""
  10. def forward_train(self,
  11. x,
  12. img_metas,
  13. proposal_list,
  14. gt_bboxes,
  15. gt_labels,
  16. gt_bboxes_ignore=None,
  17. gt_masks=None):
  18. """Forward function for training.
  19. Args:
  20. x (list[Tensor]): List of multi-level img features.
  21. img_metas (list[dict]): List of image info dict where each dict
  22. has: 'img_shape', 'scale_factor', 'flip', and may also contain
  23. 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
  24. For details on the values of these keys see
  25. `mmdet/datasets/pipelines/formatting.py:Collect`.
  26. proposals (list[Tensors]): List of region proposals.
  27. gt_bboxes (list[Tensor]): Each item are the truth boxes for each
  28. image in [tl_x, tl_y, br_x, br_y] format.
  29. gt_labels (list[Tensor]): Class indices corresponding to each box
  30. gt_bboxes_ignore (list[Tensor], optional): Specify which bounding
  31. boxes can be ignored when computing the loss.
  32. gt_masks (None | Tensor) : True segmentation masks for each box
  33. used if the architecture supports a segmentation task.
  34. Returns:
  35. dict[str, Tensor]: a dictionary of loss components
  36. """
  37. # assign gts and sample proposals
  38. if self.with_bbox or self.with_mask:
  39. num_imgs = len(img_metas)
  40. if gt_bboxes_ignore is None:
  41. gt_bboxes_ignore = [None for _ in range(num_imgs)]
  42. sampling_results = []
  43. neg_label_weights = []
  44. for i in range(num_imgs):
  45. assign_result = self.bbox_assigner.assign(
  46. proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
  47. gt_labels[i])
  48. sampling_result = self.bbox_sampler.sample(
  49. assign_result,
  50. proposal_list[i],
  51. gt_bboxes[i],
  52. gt_labels[i],
  53. feats=[lvl_feat[i][None] for lvl_feat in x])
  54. # neg label weight is obtained by sampling when using ISR-N
  55. neg_label_weight = None
  56. if isinstance(sampling_result, tuple):
  57. sampling_result, neg_label_weight = sampling_result
  58. sampling_results.append(sampling_result)
  59. neg_label_weights.append(neg_label_weight)
  60. losses = dict()
  61. # bbox head forward and loss
  62. if self.with_bbox:
  63. bbox_results = self._bbox_forward_train(
  64. x,
  65. sampling_results,
  66. gt_bboxes,
  67. gt_labels,
  68. img_metas,
  69. neg_label_weights=neg_label_weights)
  70. losses.update(bbox_results['loss_bbox'])
  71. # mask head forward and loss
  72. if self.with_mask:
  73. mask_results = self._mask_forward_train(x, sampling_results,
  74. bbox_results['bbox_feats'],
  75. gt_masks, img_metas)
  76. losses.update(mask_results['loss_mask'])
  77. return losses
  78. def _bbox_forward(self, x, rois):
  79. """Box forward function used in both training and testing."""
  80. # TODO: a more flexible way to decide which feature maps to use
  81. bbox_feats = self.bbox_roi_extractor(
  82. x[:self.bbox_roi_extractor.num_inputs], rois)
  83. if self.with_shared_head:
  84. bbox_feats = self.shared_head(bbox_feats)
  85. cls_score, bbox_pred = self.bbox_head(bbox_feats)
  86. bbox_results = dict(
  87. cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
  88. return bbox_results
  89. def _bbox_forward_train(self,
  90. x,
  91. sampling_results,
  92. gt_bboxes,
  93. gt_labels,
  94. img_metas,
  95. neg_label_weights=None):
  96. """Run forward function and calculate loss for box head in training."""
  97. rois = bbox2roi([res.bboxes for res in sampling_results])
  98. bbox_results = self._bbox_forward(x, rois)
  99. bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
  100. gt_labels, self.train_cfg)
  101. # neg_label_weights obtained by sampler is image-wise, mapping back to
  102. # the corresponding location in label weights
  103. if neg_label_weights[0] is not None:
  104. label_weights = bbox_targets[1]
  105. cur_num_rois = 0
  106. for i in range(len(sampling_results)):
  107. num_pos = sampling_results[i].pos_inds.size(0)
  108. num_neg = sampling_results[i].neg_inds.size(0)
  109. label_weights[cur_num_rois + num_pos:cur_num_rois + num_pos +
  110. num_neg] = neg_label_weights[i]
  111. cur_num_rois += num_pos + num_neg
  112. cls_score = bbox_results['cls_score']
  113. bbox_pred = bbox_results['bbox_pred']
  114. # Apply ISR-P
  115. isr_cfg = self.train_cfg.get('isr', None)
  116. if isr_cfg is not None:
  117. bbox_targets = isr_p(
  118. cls_score,
  119. bbox_pred,
  120. bbox_targets,
  121. rois,
  122. sampling_results,
  123. self.bbox_head.loss_cls,
  124. self.bbox_head.bbox_coder,
  125. **isr_cfg,
  126. num_class=self.bbox_head.num_classes)
  127. loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, rois,
  128. *bbox_targets)
  129. # Add CARL Loss
  130. carl_cfg = self.train_cfg.get('carl', None)
  131. if carl_cfg is not None:
  132. loss_carl = carl_loss(
  133. cls_score,
  134. bbox_targets[0],
  135. bbox_pred,
  136. bbox_targets[2],
  137. self.bbox_head.loss_bbox,
  138. **carl_cfg,
  139. num_class=self.bbox_head.num_classes)
  140. loss_bbox.update(loss_carl)
  141. bbox_results.update(loss_bbox=loss_bbox)
  142. return bbox_results

No Description

Contributors (3)