You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

pisa_ssd_head.py 5.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. from mmdet.core import multi_apply
  4. from ..builder import HEADS
  5. from ..losses import CrossEntropyLoss, SmoothL1Loss, carl_loss, isr_p
  6. from .ssd_head import SSDHead
  7. # TODO: add loss evaluator for SSD
  8. @HEADS.register_module()
  9. class PISASSDHead(SSDHead):
  10. def loss(self,
  11. cls_scores,
  12. bbox_preds,
  13. gt_bboxes,
  14. gt_labels,
  15. img_metas,
  16. gt_bboxes_ignore=None):
  17. """Compute losses of the head.
  18. Args:
  19. cls_scores (list[Tensor]): Box scores for each scale level
  20. Has shape (N, num_anchors * num_classes, H, W)
  21. bbox_preds (list[Tensor]): Box energies / deltas for each scale
  22. level with shape (N, num_anchors * 4, H, W)
  23. gt_bboxes (list[Tensor]): Ground truth bboxes of each image
  24. with shape (num_obj, 4).
  25. gt_labels (list[Tensor]): Ground truth labels of each image
  26. with shape (num_obj, 4).
  27. img_metas (list[dict]): Meta information of each image, e.g.,
  28. image size, scaling factor, etc.
  29. gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image.
  30. Default: None.
  31. Returns:
  32. dict: Loss dict, comprise classification loss regression loss and
  33. carl loss.
  34. """
  35. featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
  36. assert len(featmap_sizes) == self.prior_generator.num_levels
  37. device = cls_scores[0].device
  38. anchor_list, valid_flag_list = self.get_anchors(
  39. featmap_sizes, img_metas, device=device)
  40. cls_reg_targets = self.get_targets(
  41. anchor_list,
  42. valid_flag_list,
  43. gt_bboxes,
  44. img_metas,
  45. gt_bboxes_ignore_list=gt_bboxes_ignore,
  46. gt_labels_list=gt_labels,
  47. label_channels=1,
  48. unmap_outputs=False,
  49. return_sampling_results=True)
  50. if cls_reg_targets is None:
  51. return None
  52. (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
  53. num_total_pos, num_total_neg, sampling_results_list) = cls_reg_targets
  54. num_images = len(img_metas)
  55. all_cls_scores = torch.cat([
  56. s.permute(0, 2, 3, 1).reshape(
  57. num_images, -1, self.cls_out_channels) for s in cls_scores
  58. ], 1)
  59. all_labels = torch.cat(labels_list, -1).view(num_images, -1)
  60. all_label_weights = torch.cat(label_weights_list,
  61. -1).view(num_images, -1)
  62. all_bbox_preds = torch.cat([
  63. b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
  64. for b in bbox_preds
  65. ], -2)
  66. all_bbox_targets = torch.cat(bbox_targets_list,
  67. -2).view(num_images, -1, 4)
  68. all_bbox_weights = torch.cat(bbox_weights_list,
  69. -2).view(num_images, -1, 4)
  70. # concat all level anchors to a single tensor
  71. all_anchors = []
  72. for i in range(num_images):
  73. all_anchors.append(torch.cat(anchor_list[i]))
  74. isr_cfg = self.train_cfg.get('isr', None)
  75. all_targets = (all_labels.view(-1), all_label_weights.view(-1),
  76. all_bbox_targets.view(-1,
  77. 4), all_bbox_weights.view(-1, 4))
  78. # apply ISR-P
  79. if isr_cfg is not None:
  80. all_targets = isr_p(
  81. all_cls_scores.view(-1, all_cls_scores.size(-1)),
  82. all_bbox_preds.view(-1, 4),
  83. all_targets,
  84. torch.cat(all_anchors),
  85. sampling_results_list,
  86. loss_cls=CrossEntropyLoss(),
  87. bbox_coder=self.bbox_coder,
  88. **self.train_cfg.isr,
  89. num_class=self.num_classes)
  90. (new_labels, new_label_weights, new_bbox_targets,
  91. new_bbox_weights) = all_targets
  92. all_labels = new_labels.view(all_labels.shape)
  93. all_label_weights = new_label_weights.view(all_label_weights.shape)
  94. all_bbox_targets = new_bbox_targets.view(all_bbox_targets.shape)
  95. all_bbox_weights = new_bbox_weights.view(all_bbox_weights.shape)
  96. # add CARL loss
  97. carl_loss_cfg = self.train_cfg.get('carl', None)
  98. if carl_loss_cfg is not None:
  99. loss_carl = carl_loss(
  100. all_cls_scores.view(-1, all_cls_scores.size(-1)),
  101. all_targets[0],
  102. all_bbox_preds.view(-1, 4),
  103. all_targets[2],
  104. SmoothL1Loss(beta=1.),
  105. **self.train_cfg.carl,
  106. avg_factor=num_total_pos,
  107. num_class=self.num_classes)
  108. # check NaN and Inf
  109. assert torch.isfinite(all_cls_scores).all().item(), \
  110. 'classification scores become infinite or NaN!'
  111. assert torch.isfinite(all_bbox_preds).all().item(), \
  112. 'bbox predications become infinite or NaN!'
  113. losses_cls, losses_bbox = multi_apply(
  114. self.loss_single,
  115. all_cls_scores,
  116. all_bbox_preds,
  117. all_anchors,
  118. all_labels,
  119. all_label_weights,
  120. all_bbox_targets,
  121. all_bbox_weights,
  122. num_total_samples=num_total_pos)
  123. loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
  124. if carl_loss_cfg is not None:
  125. loss_dict.update(loss_carl)
  126. return loss_dict

No Description

Contributors (3)