pisa_loss.py

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch

from mmdet.core import bbox_overlaps


@mmcv.jit(derivate=True, coderize=True)
def isr_p(cls_score,
          bbox_pred,
          bbox_targets,
          rois,
          sampling_results,
          loss_cls,
          bbox_coder,
          k=2,
          bias=0,
          num_class=80):
    """Importance-based Sample Reweighting (ISR_P), positive part.

    Args:
        cls_score (Tensor): Predicted classification scores.
        bbox_pred (Tensor): Predicted bbox deltas.
        bbox_targets (tuple[Tensor]): A tuple of bbox targets, namely
            labels, label_weights, bbox_targets and bbox_weights.
        rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
            (two_stage) in shape (n, 5).
        sampling_results (obj): Sampling results.
        loss_cls (func): Classification loss function of the head.
        bbox_coder (obj): BBox coder of the head.
        k (float): Power of the non-linear mapping.
        bias (float): Shift of the non-linear mapping.
        num_class (int): Number of classes, default: 80.

    Returns:
        tuple[Tensor]: labels, imp_based_label_weights, bbox_targets,
            bbox_target_weights
    """
    labels, label_weights, bbox_targets, bbox_weights = bbox_targets
    pos_label_inds = ((labels >= 0) &
                      (labels < num_class)).nonzero().reshape(-1)
    pos_labels = labels[pos_label_inds]

    # if there are no positive samples, return the original targets
    num_pos = float(pos_label_inds.size(0))
    if num_pos == 0:
        return labels, label_weights, bbox_targets, bbox_weights

    # merge the per-image pos_assigned_gt_inds into a single tensor
    gts = list()
    last_max_gt = 0
    for i in range(len(sampling_results)):
        gt_i = sampling_results[i].pos_assigned_gt_inds
        gts.append(gt_i + last_max_gt)
        if len(gt_i) != 0:
            last_max_gt = gt_i.max() + 1
    gts = torch.cat(gts)
    assert len(gts) == num_pos

    cls_score = cls_score.detach()
    bbox_pred = bbox_pred.detach()

    # For single-stage detectors, rois here are anchors in shape (N, 4).
    # For two-stage detectors, rois are in shape (N, 5).
    if rois.size(-1) == 5:
        pos_rois = rois[pos_label_inds][:, 1:]
    else:
        pos_rois = rois[pos_label_inds]

    if bbox_pred.size(-1) > 4:
        bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4)
    else:
        pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4)

    # compute the IoU of each predicted bbox with its corresponding GT
    pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4)
    pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred)
    target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target)
    ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True)

    pos_imp_weights = label_weights[pos_label_inds]
    # Two steps to compute IoU-HLR: samples are first sorted by IoU locally
    # within each GT group, then sorted again across the same-rank groups
    # (see the rank-trick sketch after this function).
    max_l_num = pos_labels.bincount().max()
    for label in pos_labels.unique():
        l_inds = (pos_labels == label).nonzero().view(-1)
        l_gts = gts[l_inds]
        for t in l_gts.unique():
            t_inds = l_inds[l_gts == t]
            t_ious = ious[t_inds]
            _, t_iou_rank_idx = t_ious.sort(descending=True)
            _, t_iou_rank = t_iou_rank_idx.sort()
            ious[t_inds] += max_l_num - t_iou_rank.float()
        l_ious = ious[l_inds]
        _, l_iou_rank_idx = l_ious.sort(descending=True)
        _, l_iou_rank = l_iou_rank_idx.sort()  # IoU-HLR
        # linearly map HLR to label weights
        pos_imp_weights[l_inds] *= (max_l_num - l_iou_rank.float()) / max_l_num

    pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k)

    # normalize so that the reweighted loss equals the original loss value
    pos_loss_cls = loss_cls(
        cls_score[pos_label_inds], pos_labels, reduction_override='none')
    if pos_loss_cls.dim() > 1:
        ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds][:,
                                                                        None]
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None]
    else:
        ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds]
        new_pos_loss_cls = pos_loss_cls * pos_imp_weights
    pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_pos_loss_cls.sum()
    pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio
    label_weights[pos_label_inds] = pos_imp_weights

    bbox_targets = labels, label_weights, bbox_targets, bbox_weights
    return bbox_targets
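

# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original file): a minimal, self-contained
# illustration of the double-sort rank trick used in ``isr_p`` above. Sorting
# by IoU in descending order and then sorting the returned indices gives each
# sample its rank within the group (0 = highest IoU). The tensor values and
# the helper name are made up for this example.


def _iou_rank_example():
    """Toy check of the ranking step used for IoU-HLR."""
    ious = torch.tensor([0.5, 0.9, 0.7])
    _, rank_idx = ious.sort(descending=True)  # indices by decreasing IoU: [1, 2, 0]
    _, rank = rank_idx.sort()  # rank of each original sample: [2, 0, 1]
    # sample 1 (IoU 0.9) ranks first, sample 2 (IoU 0.7) second, ...
    return rank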


@mmcv.jit(derivate=True, coderize=True)
def carl_loss(cls_score,
              labels,
              bbox_pred,
              bbox_targets,
              loss_bbox,
              k=1,
              bias=0.2,
              avg_factor=None,
              sigmoid=False,
              num_class=80):
    """Classification-Aware Regression Loss (CARL).

    Args:
        cls_score (Tensor): Predicted classification scores.
        labels (Tensor): Targets of classification.
        bbox_pred (Tensor): Predicted bbox deltas.
        bbox_targets (Tensor): Target of bbox regression.
        loss_bbox (func): Regression loss function of the head.
        k (float): Power of the non-linear mapping.
        bias (float): Shift of the non-linear mapping.
        avg_factor (int): Average factor used in regression loss.
        sigmoid (bool): Activation of the classification score.
        num_class (int): Number of classes, default: 80.

    Returns:
        dict: CARL loss dict.
    """
    pos_label_inds = ((labels >= 0) &
                      (labels < num_class)).nonzero().reshape(-1)
    if pos_label_inds.numel() == 0:
        return dict(loss_carl=cls_score.sum()[None] * 0.)
    pos_labels = labels[pos_label_inds]

    # multiply pos_cls_score with the corresponding bbox weight
    # while keeping the gradient
    if sigmoid:
        pos_cls_score = cls_score.sigmoid()[pos_label_inds, pos_labels]
    else:
        pos_cls_score = cls_score.softmax(-1)[pos_label_inds, pos_labels]
    carl_loss_weights = (bias + (1 - bias) * pos_cls_score).pow(k)

    # normalize carl_loss_weights so that their sum equals the number of
    # positive samples
    num_pos = float(pos_cls_score.size(0))
    weight_ratio = num_pos / carl_loss_weights.sum()
    carl_loss_weights *= weight_ratio

    if avg_factor is None:
        avg_factor = bbox_targets.size(0)
    # if the head is class-agnostic, bbox_pred is in shape (N, 4);
    # otherwise it is in shape (N, #classes * 4) and is viewed as
    # (N, #classes, 4)
    if bbox_pred.size(-1) > 4:
        bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
        pos_bbox_preds = bbox_pred[pos_label_inds, pos_labels]
    else:
        pos_bbox_preds = bbox_pred[pos_label_inds]
    ori_loss_reg = loss_bbox(
        pos_bbox_preds,
        bbox_targets[pos_label_inds],
        reduction_override='none') / avg_factor
    loss_carl = (ori_loss_reg * carl_loss_weights[:, None]).sum()
    return dict(loss_carl=loss_carl[None])
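

# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original file): how ``carl_loss``
# might be called with toy inputs. The ``toy_l1_loss`` helper below is a
# hypothetical stand-in that only mimics the ``reduction_override`` keyword
# of mmdet loss modules; real heads pass their configured ``loss_bbox``.


def _carl_loss_example():
    """Call carl_loss on random data and return the resulting loss dict."""
    num_samples, num_classes = 8, 80

    def toy_l1_loss(pred, target, reduction_override='none'):
        # element-wise L1 loss, shape (num_pos, 4), i.e. reduction='none'
        return (pred - target).abs()

    cls_score = torch.randn(num_samples, num_classes + 1)
    labels = torch.randint(0, num_classes + 1, (num_samples, ))  # 80 = bg
    bbox_pred = torch.randn(num_samples, 4)
    bbox_targets = torch.randn(num_samples, 4)
    return carl_loss(
        cls_score,
        labels,
        bbox_pred,
        bbox_targets,
        toy_l1_loss,
        k=1,
        bias=0.2,
        num_class=num_classes)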
