gfocal_loss.py
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def quality_focal_loss(pred, target, beta=2.0):
    r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
    Qualified and Distributed Bounding Boxes for Dense Object Detection
    <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): Predicted joint representation of classification
            and quality (IoU) estimation with shape (N, C), C is the number
            of classes.
        target (tuple([torch.Tensor])): Target category label with shape (N,)
            and target quality label with shape (N,).
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    assert len(target) == 2, """target for QFL must be a tuple of two elements,
        including category label and quality label, respectively"""
    # label denotes the category id, score denotes the quality score
    label, score = target

    # negatives are supervised by 0 quality score
    pred_sigmoid = pred.sigmoid()
    scale_factor = pred_sigmoid
    zerolabel = scale_factor.new_zeros(pred.shape)
    loss = F.binary_cross_entropy_with_logits(
        pred, zerolabel, reduction='none') * scale_factor.pow(beta)

    # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
    bg_class_ind = pred.size(1)
    pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
    pos_label = label[pos].long()
    # positives are supervised by bbox quality (IoU) score
    scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
    loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
        pred[pos, pos_label], score[pos],
        reduction='none') * scale_factor.abs().pow(beta)

    loss = loss.sum(dim=1, keepdim=False)
    return loss

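# A minimal usage sketch (illustrative; not part of the original file). The
# `weighted_loss` decorator adds `weight`, `reduction` and `avg_factor`
# arguments, so a bare call reduces with the default reduction='mean'. The
# shapes and the background convention follow the docstring above:
#
# >>> import torch
# >>> pred = torch.randn(4, 80)              # 4 samples, 80 classes
# >>> label = torch.tensor([2, 11, 80, 80])  # 80 == background class id
# >>> score = torch.tensor([0.8, 0.5, 0.0, 0.0])
# >>> loss = quality_focal_loss(pred, (label, score))
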
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def distribution_focal_loss(pred, label):
    r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss:
    Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        pred (torch.Tensor): Predicted general distribution of bounding boxes
            (before softmax) with shape (N, n+1), n is the max value of the
            integral set `{0, ..., n}` in paper.
        label (torch.Tensor): Target distance label for bounding boxes with
            shape (N,).

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    # A continuous target y falls between the two nearest integer bins,
    # dis_left = floor(y) and dis_right = floor(y) + 1.
    dis_left = label.long()
    dis_right = dis_left + 1
    # Linear interpolation weights: the closer bin gets the larger weight,
    # and the two weights sum to 1.
    weight_left = dis_right.float() - label
    weight_right = label - dis_left.float()
    loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
        + F.cross_entropy(pred, dis_right, reduction='none') * weight_right
    return loss

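# A minimal usage sketch (illustrative; not part of the original file). For a
# fractional target such as 2.4, the loss becomes
# 0.6 * CE(pred, 2) + 0.4 * CE(pred, 3), so the two nearest bins share the
# supervision in proportion to their distance from the target:
#
# >>> import torch
# >>> pred = torch.randn(3, 17)               # n = 16, i.e. 17 bins
# >>> label = torch.tensor([2.4, 0.0, 15.9])  # targets must lie in [0, 16)
# >>> loss = distribution_focal_loss(pred, label)
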
@LOSSES.register_module()
class QualityFocalLoss(nn.Module):
    r"""Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss:
    Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        use_sigmoid (bool): Whether sigmoid operation is conducted in QFL.
            Defaults to True.
        beta (float): The beta parameter for calculating the modulating
            factor. Defaults to 2.0.
        reduction (str): Options are "none", "mean" and "sum".
        loss_weight (float): Loss weight of current loss.
    """

    def __init__(self,
                 use_sigmoid=True,
                 beta=2.0,
                 reduction='mean',
                 loss_weight=1.0):
        super(QualityFocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid in QFL supported now.'
        self.use_sigmoid = use_sigmoid
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted joint representation of
                classification and quality (IoU) estimation with shape
                (N, C), C is the number of classes.
            target (tuple([torch.Tensor])): Target category label with shape
                (N,) and target quality label with shape (N,).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            loss_cls = self.loss_weight * quality_focal_loss(
                pred,
                target,
                weight,
                beta=self.beta,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return loss_cls

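# A minimal usage sketch (illustrative; not part of the original file). In an
# MMDetection config this loss is usually built through the LOSSES registry,
# e.g. dict(type='QualityFocalLoss', use_sigmoid=True, beta=2.0,
# loss_weight=1.0); a direct call takes the same (label, score) target tuple
# as `quality_focal_loss`:
#
# >>> import torch
# >>> loss_qfl = QualityFocalLoss(beta=2.0, loss_weight=1.0)
# >>> pred = torch.randn(4, 80)
# >>> label = torch.tensor([2, 11, 80, 80])
# >>> score = torch.tensor([0.8, 0.5, 0.0, 0.0])
# >>> loss = loss_qfl(pred, (label, score))
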
@LOSSES.register_module()
class DistributionFocalLoss(nn.Module):
    r"""Distribution Focal Loss (DFL) is a variant of `Generalized Focal
    Loss: Learning Qualified and Distributed Bounding Boxes for Dense Object
    Detection <https://arxiv.org/abs/2006.04388>`_.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(DistributionFocalLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted general distribution of bounding
                boxes (before softmax) with shape (N, n+1), n is the max
                value of the integral set `{0, ..., n}` in paper.
            target (torch.Tensor): Target distance label for bounding boxes
                with shape (N,).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_cls = self.loss_weight * distribution_focal_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_cls

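# A minimal usage sketch (illustrative; not part of the original file). GFL
# configs typically pair this with QFL, e.g.
# dict(type='DistributionFocalLoss', loss_weight=0.25):
#
# >>> import torch
# >>> loss_dfl = DistributionFocalLoss(loss_weight=0.25)
# >>> pred = torch.randn(3, 17)
# >>> label = torch.tensor([2.4, 0.0, 15.9])
# >>> loss = loss_dfl(pred, label)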
