
ld_head.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import force_fp32

from mmdet.core import bbox_overlaps, multi_apply, reduce_mean
from ..builder import HEADS, build_loss
from .gfl_head import GFLHead


@HEADS.register_module()
class LDHead(GFLHead):
    """Localization Distillation Head.

    It utilizes the learned bbox distributions to transfer the localization
    dark knowledge from teacher to student. Original paper: `Localization
    Distillation for Object Detection. <https://arxiv.org/abs/2102.12252>`_

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        loss_ld (dict): Config of Localization Distillation Loss (LD),
            T is the temperature for distillation.
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 loss_ld=dict(
                     type='LocalizationDistillationLoss',
                     loss_weight=0.25,
                     T=10),
                 **kwargs):
        super(LDHead, self).__init__(num_classes, in_channels, **kwargs)
        self.loss_ld = build_loss(loss_ld)
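
    # --- Hedged usage sketch -------------------------------------------
    # The head is normally built from a config dict via the HEADS registry
    # rather than constructed by hand. The snippet below is illustrative
    # only: the class count and channel sizes are made up, and the loss_ld
    # type name may differ between mmdet releases.
    #
    #     from mmdet.models import build_head
    #
    #     ld_head = build_head(
    #         dict(
    #             type='LDHead',
    #             num_classes=80,
    #             in_channels=256,
    #             feat_channels=256,
    #             loss_ld=dict(
    #                 type='LocalizationDistillationLoss',
    #                 loss_weight=0.25,
    #                 T=10)))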

    def loss_single(self, anchors, cls_score, bbox_pred, labels,
                    label_weights, bbox_targets, stride, soft_targets,
                    num_total_samples):
        """Compute loss of a single scale level.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Cls and quality joint scores for each scale
                level with shape (N, num_classes, H, W).
            bbox_pred (Tensor): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), where n is the max
                value of the integral set.
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            stride (tuple): Stride in this scale level.
            soft_targets (Tensor): Teacher box distribution logits for this
                scale level with shape (N, 4*(n+1), H, W).
            num_total_samples (int): Number of positive samples that is
                reduced over all GPUs.

        Returns:
            tuple[Tensor]: Loss components (cls, bbox, dfl, ld) and the sum
                of the quality weight targets.
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3,
                                      1).reshape(-1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3,
                                      1).reshape(-1, 4 * (self.reg_max + 1))
        soft_targets = soft_targets.permute(0, 2, 3,
                                            1).reshape(-1,
                                                       4 * (self.reg_max + 1))

        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]

            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
            pos_bbox_pred_corners = self.integral(pos_bbox_pred)
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchor_centers, pos_bbox_pred_corners)
            pos_decode_bbox_targets = pos_bbox_targets / stride[0]
            score[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)
            pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
            pos_soft_targets = soft_targets[pos_inds]
            soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)

            target_corners = self.bbox_coder.encode(pos_anchor_centers,
                                                    pos_decode_bbox_targets,
                                                    self.reg_max).reshape(-1)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0)

            # dfl loss
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0)

            # ld loss
            loss_ld = self.loss_ld(
                pred_corners,
                soft_corners,
                weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
                avg_factor=4.0)
        else:
            loss_ld = bbox_pred.sum() * 0
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            weight_targets = bbox_pred.new_tensor(0)

        # cls (qfl) loss
        loss_cls = self.loss_cls(
            cls_score, (labels, score),
            weight=label_weights,
            avg_factor=num_total_samples)

        return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()
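
    # --- Hedged sketch of the LD term ------------------------------------
    # LD matches the student's per-edge distribution over the integral set
    # (reg_max + 1 bins per box side) against the teacher's, typically via
    # a temperature-scaled KL divergence. A minimal stand-alone version,
    # under the assumption that loss_ld reduces to a KD-style KL per corner
    # (the helper name `ld_kl` is hypothetical), would be:
    #
    #     import torch.nn.functional as F
    #
    #     def ld_kl(pred_corners, soft_corners, T=10.0):
    #         # both inputs: (num_pos * 4, reg_max + 1) raw logits
    #         p_teacher = F.softmax(soft_corners / T, dim=1)
    #         log_p_student = F.log_softmax(pred_corners / T, dim=1)
    #         return (T * T) * F.kl_div(
    #             log_p_student, p_teacher, reduction='none').mean(dim=1)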

    def forward_train(self,
                      x,
                      out_teacher,
                      img_metas,
                      gt_bboxes,
                      gt_labels=None,
                      gt_bboxes_ignore=None,
                      proposal_cfg=None,
                      **kwargs):
        """
        Args:
            x (list[Tensor]): Features from FPN.
            out_teacher (tuple[list[Tensor]]): Outputs of the teacher head
                on the same images; out_teacher[1] holds the per-level box
                distribution logits used as soft targets.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (Tensor): Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels (Tensor): Ground truth labels of each box,
                shape (num_gts,).
            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
                ignored, shape (num_ignored_gts, 4).
            proposal_cfg (mmcv.Config): Test / postprocessing configuration.
                If None, test_cfg would be used.

        Returns:
            tuple[dict, list]: The loss components and proposals of each
                image.

                - losses (dict[str, Tensor]): A dictionary of loss
                  components.
                - proposal_list (list[Tensor]): Proposals of each image.
        """
        outs = self(x)
        soft_target = out_teacher[1]
        if gt_labels is None:
            loss_inputs = outs + (gt_bboxes, soft_target, img_metas)
        else:
            loss_inputs = outs + (gt_bboxes, gt_labels, soft_target,
                                  img_metas)
        losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        if proposal_cfg is None:
            return losses
        else:
            proposal_list = self.get_bboxes(
                *outs, img_metas, cfg=proposal_cfg)
            return losses, proposal_list
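
    # Hedged note: in mmdet's knowledge-distillation single-stage detector
    # wrapper, `out_teacher` is typically computed from the same images
    # with gradients disabled, along these lines:
    #
    #     with torch.no_grad():
    #         teacher_x = self.teacher_model.extract_feat(img)
    #         out_teacher = self.teacher_model.bbox_head(teacher_x)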

    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             soft_target,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Cls and quality scores for each scale
                level with shape (N, num_classes, H, W).
            bbox_preds (list[Tensor]): Box distribution logits for each
                scale level with shape (N, 4*(n+1), H, W), where n is the
                max value of the integral set.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image
                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each
                box.
            soft_target (list[Tensor]): Teacher box distribution logits for
                each scale level with shape (N, 4*(n+1), H, W).
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, img_metas, device=device)
        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            gt_bboxes,
            img_metas,
            gt_bboxes_ignore_list=gt_bboxes_ignore,
            gt_labels_list=gt_labels,
            label_channels=label_channels)
        if cls_reg_targets is None:
            return None

        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets

        num_total_samples = reduce_mean(
            torch.tensor(num_total_pos, dtype=torch.float,
                         device=device)).item()
        num_total_samples = max(num_total_samples, 1.0)

        losses_cls, losses_bbox, losses_dfl, losses_ld, \
            avg_factor = multi_apply(
                self.loss_single,
                anchor_list,
                cls_scores,
                bbox_preds,
                labels_list,
                label_weights_list,
                bbox_targets_list,
                self.prior_generator.strides,
                soft_target,
                num_total_samples=num_total_samples)

        avg_factor = sum(avg_factor) + 1e-6
        avg_factor = reduce_mean(avg_factor).item()
        losses_bbox = [x / avg_factor for x in losses_bbox]
        losses_dfl = [x / avg_factor for x in losses_dfl]
        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_dfl=losses_dfl,
            loss_ld=losses_ld)
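
# Note: losses_bbox and losses_dfl are re-normalized by the (distributed)
# sum of the quality weight targets returned from loss_single, while
# losses_ld keeps the per-corner avg_factor of 4.0 applied inside
# loss_single.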
