You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

smooth_l1_loss.py 4.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mmcv
  3. import torch
  4. import torch.nn as nn
  5. from ..builder import LOSSES
  6. from .utils import weighted_loss
  7. @mmcv.jit(derivate=True, coderize=True)
  8. @weighted_loss
  9. def smooth_l1_loss(pred, target, beta=1.0):
  10. """Smooth L1 loss.
  11. Args:
  12. pred (torch.Tensor): The prediction.
  13. target (torch.Tensor): The learning target of the prediction.
  14. beta (float, optional): The threshold in the piecewise function.
  15. Defaults to 1.0.
  16. Returns:
  17. torch.Tensor: Calculated loss
  18. """
  19. assert beta > 0
  20. if target.numel() == 0:
  21. return pred.sum() * 0
  22. assert pred.size() == target.size()
  23. diff = torch.abs(pred - target)
  24. loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
  25. diff - 0.5 * beta)
  26. return loss
  27. @mmcv.jit(derivate=True, coderize=True)
  28. @weighted_loss
  29. def l1_loss(pred, target):
  30. """L1 loss.
  31. Args:
  32. pred (torch.Tensor): The prediction.
  33. target (torch.Tensor): The learning target of the prediction.
  34. Returns:
  35. torch.Tensor: Calculated loss
  36. """
  37. if target.numel() == 0:
  38. return pred.sum() * 0
  39. assert pred.size() == target.size()
  40. loss = torch.abs(pred - target)
  41. return loss
  42. @LOSSES.register_module()
  43. class SmoothL1Loss(nn.Module):
  44. """Smooth L1 loss.
  45. Args:
  46. beta (float, optional): The threshold in the piecewise function.
  47. Defaults to 1.0.
  48. reduction (str, optional): The method to reduce the loss.
  49. Options are "none", "mean" and "sum". Defaults to "mean".
  50. loss_weight (float, optional): The weight of loss.
  51. """
  52. def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
  53. super(SmoothL1Loss, self).__init__()
  54. self.beta = beta
  55. self.reduction = reduction
  56. self.loss_weight = loss_weight
  57. def forward(self,
  58. pred,
  59. target,
  60. weight=None,
  61. avg_factor=None,
  62. reduction_override=None,
  63. **kwargs):
  64. """Forward function.
  65. Args:
  66. pred (torch.Tensor): The prediction.
  67. target (torch.Tensor): The learning target of the prediction.
  68. weight (torch.Tensor, optional): The weight of loss for each
  69. prediction. Defaults to None.
  70. avg_factor (int, optional): Average factor that is used to average
  71. the loss. Defaults to None.
  72. reduction_override (str, optional): The reduction method used to
  73. override the original reduction method of the loss.
  74. Defaults to None.
  75. """
  76. assert reduction_override in (None, 'none', 'mean', 'sum')
  77. reduction = (
  78. reduction_override if reduction_override else self.reduction)
  79. loss_bbox = self.loss_weight * smooth_l1_loss(
  80. pred,
  81. target,
  82. weight,
  83. beta=self.beta,
  84. reduction=reduction,
  85. avg_factor=avg_factor,
  86. **kwargs)
  87. return loss_bbox
  88. @LOSSES.register_module()
  89. class L1Loss(nn.Module):
  90. """L1 loss.
  91. Args:
  92. reduction (str, optional): The method to reduce the loss.
  93. Options are "none", "mean" and "sum".
  94. loss_weight (float, optional): The weight of loss.
  95. """
  96. def __init__(self, reduction='mean', loss_weight=1.0):
  97. super(L1Loss, self).__init__()
  98. self.reduction = reduction
  99. self.loss_weight = loss_weight
  100. def forward(self,
  101. pred,
  102. target,
  103. weight=None,
  104. avg_factor=None,
  105. reduction_override=None):
  106. """Forward function.
  107. Args:
  108. pred (torch.Tensor): The prediction.
  109. target (torch.Tensor): The learning target of the prediction.
  110. weight (torch.Tensor, optional): The weight of loss for each
  111. prediction. Defaults to None.
  112. avg_factor (int, optional): Average factor that is used to average
  113. the loss. Defaults to None.
  114. reduction_override (str, optional): The reduction method used to
  115. override the original reduction method of the loss.
  116. Defaults to None.
  117. """
  118. assert reduction_override in (None, 'none', 'mean', 'sum')
  119. reduction = (
  120. reduction_override if reduction_override else self.reduction)
  121. loss_bbox = self.loss_weight * l1_loss(
  122. pred, target, weight, reduction=reduction, avg_factor=avg_factor)
  123. return loss_bbox

No Description

Contributors (3)