pisa_retinanet_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import force_fp32

from mmdet.core import images_to_levels
from ..builder import HEADS
from ..losses import carl_loss, isr_p
from .retina_head import RetinaHead


@HEADS.register_module()
class PISARetinaHead(RetinaHead):
  10. """PISA Retinanet Head.
  11. The head owns the same structure with Retinanet Head, but differs in two
  12. aspects:
  13. 1. Importance-based Sample Reweighting Positive (ISR-P) is applied to
  14. change the positive loss weights.
  15. 2. Classification-aware regression loss is adopted as a third loss.
  16. """
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level.
                Has shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_anchors * 4, H, W).
            gt_bboxes (list[Tensor]): Ground truth bboxes of each image
                with shape (num_obj, 4).
            gt_labels (list[Tensor]): Ground truth labels of each image
                with shape (num_obj, ).
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image.
                Default: None.

        Returns:
            dict: Loss dict, comprising the classification loss, the
                regression loss and the CARL loss.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device

        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels,
            return_sampling_results=True)
        if cls_reg_targets is None:
            return None
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg, sampling_results_list) = cls_reg_targets
        num_total_samples = (
            num_total_pos + num_total_neg if self.sampling else num_total_pos)

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        # concat all level anchors and flags to a single tensor
        concat_anchor_list = []
        for i in range(len(anchor_list)):
            concat_anchor_list.append(torch.cat(anchor_list[i]))
        all_anchor_list = images_to_levels(concat_anchor_list,
                                           num_level_anchors)

        num_imgs = len(img_metas)
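        # Flatten predictions, targets and anchors across images and FPN
        # levels into single per-sample tensors, so that ISR-P and CARL can
        # reweight every sample in one pass below.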
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, label_channels)
            for cls_score in cls_scores
        ]
        flatten_cls_scores = torch.cat(
            flatten_cls_scores, dim=1).reshape(-1,
                                               flatten_cls_scores[0].size(-1))
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_bbox_preds = torch.cat(
            flatten_bbox_preds, dim=1).view(-1, flatten_bbox_preds[0].size(-1))
        flatten_labels = torch.cat(labels_list, dim=1).reshape(-1)
        flatten_label_weights = torch.cat(
            label_weights_list, dim=1).reshape(-1)
        flatten_anchors = torch.cat(all_anchor_list, dim=1).reshape(-1, 4)
        flatten_bbox_targets = torch.cat(
            bbox_targets_list, dim=1).reshape(-1, 4)
        flatten_bbox_weights = torch.cat(
            bbox_weights_list, dim=1).reshape(-1, 4)

        # Apply ISR-P
        isr_cfg = self.train_cfg.get('isr', None)
        if isr_cfg is not None:
            all_targets = (flatten_labels, flatten_label_weights,
                           flatten_bbox_targets, flatten_bbox_weights)
            with torch.no_grad():
                all_targets = isr_p(
                    flatten_cls_scores,
                    flatten_bbox_preds,
                    all_targets,
                    flatten_anchors,
                    sampling_results_list,
                    bbox_coder=self.bbox_coder,
                    loss_cls=self.loss_cls,
                    num_class=self.num_classes,
                    **self.train_cfg.isr)
            (flatten_labels, flatten_label_weights, flatten_bbox_targets,
             flatten_bbox_weights) = all_targets

        # For convenience, compute the loss once over all samples instead of
        # separating them by FPN level, so that the weights do not need to be
        # split by level again. The result is the same.
        losses_cls = self.loss_cls(
            flatten_cls_scores,
            flatten_labels,
            flatten_label_weights,
            avg_factor=num_total_samples)
        losses_bbox = self.loss_bbox(
            flatten_bbox_preds,
            flatten_bbox_targets,
            flatten_bbox_weights,
            avg_factor=num_total_samples)
        loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)

        # CARL Loss
        carl_cfg = self.train_cfg.get('carl', None)
        if carl_cfg is not None:
            loss_carl = carl_loss(
                flatten_cls_scores,
                flatten_labels,
                flatten_bbox_preds,
                flatten_bbox_targets,
                self.loss_bbox,
                **self.train_cfg.carl,
                avg_factor=num_total_pos,
                sigmoid=True,
                num_class=self.num_classes)
            loss_dict.update(loss_carl)
        return loss_dict
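
Both the ISR-P and CARL branches above are optional: they run only when the head's train_cfg contains an 'isr' or 'carl' entry, and the contents of those dicts are forwarded verbatim to isr_p() and carl_loss(). A minimal config sketch of how this head is typically enabled is shown below; the k/bias values follow the reference PISA configs in MMDetection, but treat the exact numbers as an assumption rather than a recommendation.

    # Sketch only: enable PISARetinaHead with ISR-P and CARL on top of a
    # standard RetinaNet config. Remaining train_cfg fields (assigner,
    # sampler, allowed_border, ...) are inherited from the base config.
    _base_ = '../retinanet/retinanet_r50_fpn_1x_coco.py'
    model = dict(
        bbox_head=dict(type='PISARetinaHead'),
        train_cfg=dict(
            isr=dict(k=2., bias=0.),      # extra kwargs forwarded to isr_p()
            carl=dict(k=1., bias=0.2)))   # extra kwargs forwarded to carl_loss()

Omitting either entry falls back to the plain RetinaNet classification and regression losses for that branch.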
