You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

varifocal_loss.py 5.4 kB

2 years ago
Lines 1–134
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mmcv
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from ..builder import LOSSES
  6. from .utils import weight_reduce_loss
  7. @mmcv.jit(derivate=True, coderize=True)
  8. def varifocal_loss(pred,
  9. target,
  10. weight=None,
  11. alpha=0.75,
  12. gamma=2.0,
  13. iou_weighted=True,
  14. reduction='mean',
  15. avg_factor=None):
  16. """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
  17. Args:
  18. pred (torch.Tensor): The prediction with shape (N, C), C is the
  19. number of classes
  20. target (torch.Tensor): The learning target of the iou-aware
  21. classification score with shape (N, C), C is the number of classes.
  22. weight (torch.Tensor, optional): The weight of loss for each
  23. prediction. Defaults to None.
  24. alpha (float, optional): A balance factor for the negative part of
  25. Varifocal Loss, which is different from the alpha of Focal Loss.
  26. Defaults to 0.75.
  27. gamma (float, optional): The gamma for calculating the modulating
  28. factor. Defaults to 2.0.
  29. iou_weighted (bool, optional): Whether to weight the loss of the
  30. positive example with the iou target. Defaults to True.
  31. reduction (str, optional): The method used to reduce the loss into
  32. a scalar. Defaults to 'mean'. Options are "none", "mean" and
  33. "sum".
  34. avg_factor (int, optional): Average factor that is used to average
  35. the loss. Defaults to None.
  36. """
  37. # pred and target should be of the same size
  38. assert pred.size() == target.size()
  39. pred_sigmoid = pred.sigmoid()
  40. target = target.type_as(pred)
  41. if iou_weighted:
  42. focal_weight = target * (target > 0.0).float() + \
  43. alpha * (pred_sigmoid - target).abs().pow(gamma) * \
  44. (target <= 0.0).float()
  45. else:
  46. focal_weight = (target > 0.0).float() + \
  47. alpha * (pred_sigmoid - target).abs().pow(gamma) * \
  48. (target <= 0.0).float()
  49. loss = F.binary_cross_entropy_with_logits(
  50. pred, target, reduction='none') * focal_weight
  51. loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
  52. return loss
  53. @LOSSES.register_module()
  54. class VarifocalLoss(nn.Module):
  55. def __init__(self,
  56. use_sigmoid=True,
  57. alpha=0.75,
  58. gamma=2.0,
  59. iou_weighted=True,
  60. reduction='mean',
  61. loss_weight=1.0):
  62. """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
  63. Args:
  64. use_sigmoid (bool, optional): Whether the prediction is
  65. used for sigmoid or softmax. Defaults to True.
  66. alpha (float, optional): A balance factor for the negative part of
  67. Varifocal Loss, which is different from the alpha of Focal
  68. Loss. Defaults to 0.75.
  69. gamma (float, optional): The gamma for calculating the modulating
  70. factor. Defaults to 2.0.
  71. iou_weighted (bool, optional): Whether to weight the loss of the
  72. positive examples with the iou target. Defaults to True.
  73. reduction (str, optional): The method used to reduce the loss into
  74. a scalar. Defaults to 'mean'. Options are "none", "mean" and
  75. "sum".
  76. loss_weight (float, optional): Weight of loss. Defaults to 1.0.
  77. """
  78. super(VarifocalLoss, self).__init__()
  79. assert use_sigmoid is True, \
  80. 'Only sigmoid varifocal loss supported now.'
  81. assert alpha >= 0.0
  82. self.use_sigmoid = use_sigmoid
  83. self.alpha = alpha
  84. self.gamma = gamma
  85. self.iou_weighted = iou_weighted
  86. self.reduction = reduction
  87. self.loss_weight = loss_weight
  88. def forward(self,
  89. pred,
  90. target,
  91. weight=None,
  92. avg_factor=None,
  93. reduction_override=None):
  94. """Forward function.
  95. Args:
  96. pred (torch.Tensor): The prediction.
  97. target (torch.Tensor): The learning target of the prediction.
  98. weight (torch.Tensor, optional): The weight of loss for each
  99. prediction. Defaults to None.
  100. avg_factor (int, optional): Average factor that is used to average
  101. the loss. Defaults to None.
  102. reduction_override (str, optional): The reduction method used to
  103. override the original reduction method of the loss.
  104. Options are "none", "mean" and "sum".
  105. Returns:
  106. torch.Tensor: The calculated loss
  107. """
  108. assert reduction_override in (None, 'none', 'mean', 'sum')
  109. reduction = (
  110. reduction_override if reduction_override else self.reduction)
  111. if self.use_sigmoid:
  112. loss_cls = self.loss_weight * varifocal_loss(
  113. pred,
  114. target,
  115. weight,
  116. alpha=self.alpha,
  117. gamma=self.gamma,
  118. iou_weighted=self.iou_weighted,
  119. reduction=reduction,
  120. avg_factor=avg_factor)
  121. else:
  122. raise NotImplementedError
  123. return loss_cls

No Description

Contributors (2)