focal_loss.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss

from ..builder import LOSSES
from .utils import weight_reduce_loss


# This method is only for debugging
def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt is the probability assigned to the *wrong* class, so pt.pow(gamma)
    # down-weights easy, well-classified examples.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    # Keep a detached per-element copy before reduction, mirroring
    # sigmoid_focal_loss below so that callers can unpack two values from
    # either code path.
    loss_batch = loss.clone().detach()
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened to shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
        loss_batch = loss_batch * weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss, loss_batch
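

# Illustrative sketch for the Python path (the shapes and class count here
# are assumptions for the example, not values used elsewhere in this file).
# Note that `target` must be a one-hot matrix of the same shape as `pred`:
#
#   pred = torch.randn(8, 4)  # logits for 8 priors, 4 classes
#   target = F.one_hot(torch.randint(0, 4, (8,)), num_classes=4).float()
#   loss, loss_batch = py_sigmoid_focal_loss(pred, target)
#   # loss: scalar ('mean' reduction); loss_batch: (8, 4) detached copy

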
def sigmoid_focal_loss(pred,
                       target,
                       weight=None,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean',
                       avg_factor=None):
    r"""A wrapper of the CUDA version of `Focal Loss
    <https://arxiv.org/abs/1708.02002>`_.

    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the number
            of classes.
        target (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0.
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25.
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
    """
    # Function.apply does not accept keyword arguments, so the decorator
    # "weighted_loss" is not applicable
    loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma,
                               alpha, None, 'none')
    # Detached per-element copy of the unreduced loss, returned alongside
    # the reduced loss.
    loss_batch = loss.clone().detach()
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                # which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed,
                # e.g. in FSAF. But it may be flattened to shape
                # (num_priors x num_class, ), while loss is still of shape
                # (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
        loss_batch = loss_batch * weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss, loss_batch
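

# Illustrative sketch for the CUDA path (assumes a CUDA device, since
# mmcv's _sigmoid_focal_loss is a CUDA op). Unlike the Python version above,
# `target` is a LongTensor of class indices rather than a one-hot matrix:
#
#   pred = torch.randn(8, 4, device='cuda')
#   target = torch.randint(0, 4, (8,), device='cuda')
#   loss, loss_batch = sigmoid_focal_loss(pred, target)
#   # loss_batch.shape == (8, 4)

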
@LOSSES.register_module()
class FocalLoss(nn.Module):

    def __init__(self,
                 use_sigmoid=True,
                 gamma=2.0,
                 alpha=0.25,
                 reduction='mean',
                 loss_weight=1.0):
        """`Focal Loss <https://arxiv.org/abs/1708.02002>`_

        Args:
            use_sigmoid (bool, optional): Whether the prediction is used for
                sigmoid or softmax. Defaults to True.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
        """
        super(FocalLoss, self).__init__()
        assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
        self.use_sigmoid = use_sigmoid
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning label of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The reduced loss and a
                detached copy of the per-element loss.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.use_sigmoid:
            if torch.cuda.is_available() and pred.is_cuda:
                # The CUDA op consumes integer class labels directly.
                calculate_loss_func = sigmoid_focal_loss
            else:
                # The Python fallback expects one-hot targets; the extra
                # column absorbs the background label `num_classes` before
                # being sliced off.
                num_classes = pred.size(1)
                target = F.one_hot(target, num_classes=num_classes + 1)
                target = target[:, :num_classes]
                calculate_loss_func = py_sigmoid_focal_loss

            loss_cls, loss_batch = calculate_loss_func(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)
        else:
            raise NotImplementedError
        return self.loss_weight * loss_cls, self.loss_weight * loss_batch
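

# Usage sketch for the registered module (illustrative only: this file uses
# relative imports, so it is imported from within the package rather than
# run directly; the import path below is an assumption for this fork):
#
#   from mmdet.models.losses.focal_loss import FocalLoss
#   criterion = FocalLoss(use_sigmoid=True, gamma=2.0, alpha=0.25)
#   pred = torch.randn(8, 4)            # logits for 8 priors, 4 classes
#   target = torch.randint(0, 4, (8,))  # integer class labels
#   loss_cls, loss_batch = criterion(pred, target)
#   # loss_cls: scalar reduced loss, scaled by loss_weight
#   # loss_batch: (8, 4) detached per-element losses, also scaled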
