You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

balanced_l1_loss.py 4.3 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mmcv
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. from ..builder import LOSSES
  7. from .utils import weighted_loss
  8. @mmcv.jit(derivate=True, coderize=True)
  9. @weighted_loss
  10. def balanced_l1_loss(pred,
  11. target,
  12. beta=1.0,
  13. alpha=0.5,
  14. gamma=1.5,
  15. reduction='mean'):
  16. """Calculate balanced L1 loss.
  17. Please see the `Libra R-CNN <https://arxiv.org/pdf/1904.02701.pdf>`_
  18. Args:
  19. pred (torch.Tensor): The prediction with shape (N, 4).
  20. target (torch.Tensor): The learning target of the prediction with
  21. shape (N, 4).
  22. beta (float): The loss is a piecewise function of prediction and target
  23. and ``beta`` serves as a threshold for the difference between the
  24. prediction and target. Defaults to 1.0.
  25. alpha (float): The denominator ``alpha`` in the balanced L1 loss.
  26. Defaults to 0.5.
  27. gamma (float): The ``gamma`` in the balanced L1 loss.
  28. Defaults to 1.5.
  29. reduction (str, optional): The method that reduces the loss to a
  30. scalar. Options are "none", "mean" and "sum".
  31. Returns:
  32. torch.Tensor: The calculated loss
  33. """
  34. assert beta > 0
  35. if target.numel() == 0:
  36. return pred.sum() * 0
  37. assert pred.size() == target.size()
  38. diff = torch.abs(pred - target)
  39. b = np.e**(gamma / alpha) - 1
  40. loss = torch.where(
  41. diff < beta, alpha / b *
  42. (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
  43. gamma * diff + gamma / b - alpha * beta)
  44. return loss
  45. @LOSSES.register_module()
  46. class BalancedL1Loss(nn.Module):
  47. """Balanced L1 Loss.
  48. arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
  49. Args:
  50. alpha (float): The denominator ``alpha`` in the balanced L1 loss.
  51. Defaults to 0.5.
  52. gamma (float): The ``gamma`` in the balanced L1 loss. Defaults to 1.5.
  53. beta (float, optional): The loss is a piecewise function of prediction
  54. and target. ``beta`` serves as a threshold for the difference
  55. between the prediction and target. Defaults to 1.0.
  56. reduction (str, optional): The method that reduces the loss to a
  57. scalar. Options are "none", "mean" and "sum".
  58. loss_weight (float, optional): The weight of the loss. Defaults to 1.0
  59. """
  60. def __init__(self,
  61. alpha=0.5,
  62. gamma=1.5,
  63. beta=1.0,
  64. reduction='mean',
  65. loss_weight=1.0):
  66. super(BalancedL1Loss, self).__init__()
  67. self.alpha = alpha
  68. self.gamma = gamma
  69. self.beta = beta
  70. self.reduction = reduction
  71. self.loss_weight = loss_weight
  72. def forward(self,
  73. pred,
  74. target,
  75. weight=None,
  76. avg_factor=None,
  77. reduction_override=None,
  78. **kwargs):
  79. """Forward function of loss.
  80. Args:
  81. pred (torch.Tensor): The prediction with shape (N, 4).
  82. target (torch.Tensor): The learning target of the prediction with
  83. shape (N, 4).
  84. weight (torch.Tensor, optional): Sample-wise loss weight with
  85. shape (N, ).
  86. avg_factor (int, optional): Average factor that is used to average
  87. the loss. Defaults to None.
  88. reduction_override (str, optional): The reduction method used to
  89. override the original reduction method of the loss.
  90. Options are "none", "mean" and "sum".
  91. Returns:
  92. torch.Tensor: The calculated loss
  93. """
  94. assert reduction_override in (None, 'none', 'mean', 'sum')
  95. reduction = (
  96. reduction_override if reduction_override else self.reduction)
  97. loss_bbox = self.loss_weight * balanced_l1_loss(
  98. pred,
  99. target,
  100. weight,
  101. alpha=self.alpha,
  102. gamma=self.gamma,
  103. beta=self.beta,
  104. reduction=reduction,
  105. avg_factor=avg_factor,
  106. **kwargs)
  107. return loss_bbox

No Description

Contributors (2)