You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gaussian_focal_loss.py 3.3 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mmcv
  3. import torch.nn as nn
  4. from ..builder import LOSSES
  5. from .utils import weighted_loss
  6. @mmcv.jit(derivate=True, coderize=True)
  7. @weighted_loss
  8. def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0):
  9. """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ for targets in gaussian
  10. distribution.
  11. Args:
  12. pred (torch.Tensor): The prediction.
  13. gaussian_target (torch.Tensor): The learning target of the prediction
  14. in gaussian distribution.
  15. alpha (float, optional): A balanced form for Focal Loss.
  16. Defaults to 2.0.
  17. gamma (float, optional): The gamma for calculating the modulating
  18. factor. Defaults to 4.0.
  19. """
  20. eps = 1e-12
  21. pos_weights = gaussian_target.eq(1)
  22. neg_weights = (1 - gaussian_target).pow(gamma)
  23. pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
  24. neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
  25. return pos_loss + neg_loss
  26. @LOSSES.register_module()
  27. class GaussianFocalLoss(nn.Module):
  28. """GaussianFocalLoss is a variant of focal loss.
  29. More details can be found in the `paper
  30. <https://arxiv.org/abs/1808.01244>`_
  31. Code is modified from `kp_utils.py
  32. <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L152>`_ # noqa: E501
  33. Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
  34. not 0/1 binary target.
  35. Args:
  36. alpha (float): Power of prediction.
  37. gamma (float): Power of target for negative samples.
  38. reduction (str): Options are "none", "mean" and "sum".
  39. loss_weight (float): Loss weight of current loss.
  40. """
  41. def __init__(self,
  42. alpha=2.0,
  43. gamma=4.0,
  44. reduction='mean',
  45. loss_weight=1.0):
  46. super(GaussianFocalLoss, self).__init__()
  47. self.alpha = alpha
  48. self.gamma = gamma
  49. self.reduction = reduction
  50. self.loss_weight = loss_weight
  51. def forward(self,
  52. pred,
  53. target,
  54. weight=None,
  55. avg_factor=None,
  56. reduction_override=None):
  57. """Forward function.
  58. Args:
  59. pred (torch.Tensor): The prediction.
  60. target (torch.Tensor): The learning target of the prediction
  61. in gaussian distribution.
  62. weight (torch.Tensor, optional): The weight of loss for each
  63. prediction. Defaults to None.
  64. avg_factor (int, optional): Average factor that is used to average
  65. the loss. Defaults to None.
  66. reduction_override (str, optional): The reduction method used to
  67. override the original reduction method of the loss.
  68. Defaults to None.
  69. """
  70. assert reduction_override in (None, 'none', 'mean', 'sum')
  71. reduction = (
  72. reduction_override if reduction_override else self.reduction)
  73. loss_reg = self.loss_weight * gaussian_focal_loss(
  74. pred,
  75. target,
  76. weight,
  77. alpha=self.alpha,
  78. gamma=self.gamma,
  79. reduction=reduction,
  80. avg_factor=avg_factor)
  81. return loss_reg

No Description

Contributors (2)