You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

iou_loss.py 16 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. import warnings
  4. import mmcv
  5. import torch
  6. import torch.nn as nn
  7. from mmdet.core import bbox_overlaps
  8. from ..builder import LOSSES
  9. from .utils import weighted_loss
  10. @mmcv.jit(derivate=True, coderize=True)
  11. @weighted_loss
  12. def iou_loss(pred, target, linear=False, mode='log', eps=1e-6):
  13. """IoU loss.
  14. Computing the IoU loss between a set of predicted bboxes and target bboxes.
  15. The loss is calculated as negative log of IoU.
  16. Args:
  17. pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
  18. shape (n, 4).
  19. target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
  20. linear (bool, optional): If True, use linear scale of loss instead of
  21. log scale. Default: False.
  22. mode (str): Loss scaling mode, including "linear", "square", and "log".
  23. Default: 'log'
  24. eps (float): Eps to avoid log(0).
  25. Return:
  26. torch.Tensor: Loss tensor.
  27. """
  28. assert mode in ['linear', 'square', 'log']
  29. if linear:
  30. mode = 'linear'
  31. warnings.warn('DeprecationWarning: Setting "linear=True" in '
  32. 'iou_loss is deprecated, please use "mode=`linear`" '
  33. 'instead.')
  34. ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps)
  35. if mode == 'linear':
  36. loss = 1 - ious
  37. elif mode == 'square':
  38. loss = 1 - ious**2
  39. elif mode == 'log':
  40. loss = -ious.log()
  41. else:
  42. raise NotImplementedError
  43. return loss
  44. @mmcv.jit(derivate=True, coderize=True)
  45. @weighted_loss
  46. def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
  47. """BIoULoss.
  48. This is an implementation of paper
  49. `Improving Object Localization with Fitness NMS and Bounded IoU Loss.
  50. <https://arxiv.org/abs/1711.00164>`_.
  51. Args:
  52. pred (torch.Tensor): Predicted bboxes.
  53. target (torch.Tensor): Target bboxes.
  54. beta (float): beta parameter in smoothl1.
  55. eps (float): eps to avoid NaN.
  56. """
  57. pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
  58. pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
  59. pred_w = pred[:, 2] - pred[:, 0]
  60. pred_h = pred[:, 3] - pred[:, 1]
  61. with torch.no_grad():
  62. target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
  63. target_ctry = (target[:, 1] + target[:, 3]) * 0.5
  64. target_w = target[:, 2] - target[:, 0]
  65. target_h = target[:, 3] - target[:, 1]
  66. dx = target_ctrx - pred_ctrx
  67. dy = target_ctry - pred_ctry
  68. loss_dx = 1 - torch.max(
  69. (target_w - 2 * dx.abs()) /
  70. (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx))
  71. loss_dy = 1 - torch.max(
  72. (target_h - 2 * dy.abs()) /
  73. (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy))
  74. loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w /
  75. (target_w + eps))
  76. loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h /
  77. (target_h + eps))
  78. # view(..., -1) does not work for empty tensor
  79. loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh],
  80. dim=-1).flatten(1)
  81. loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta,
  82. loss_comb - 0.5 * beta)
  83. return loss
  84. @mmcv.jit(derivate=True, coderize=True)
  85. @weighted_loss
  86. def giou_loss(pred, target, eps=1e-7):
  87. r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
  88. Box Regression <https://arxiv.org/abs/1902.09630>`_.
  89. Args:
  90. pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
  91. shape (n, 4).
  92. target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
  93. eps (float): Eps to avoid log(0).
  94. Return:
  95. Tensor: Loss tensor.
  96. """
  97. gious = bbox_overlaps(pred, target, mode='giou', is_aligned=True, eps=eps)
  98. loss = 1 - gious
  99. return loss
  100. @mmcv.jit(derivate=True, coderize=True)
  101. @weighted_loss
  102. def diou_loss(pred, target, eps=1e-7):
  103. r"""`Implementation of Distance-IoU Loss: Faster and Better
  104. Learning for Bounding Box Regression, https://arxiv.org/abs/1911.08287`_.
  105. Code is modified from https://github.com/Zzh-tju/DIoU.
  106. Args:
  107. pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
  108. shape (n, 4).
  109. target (Tensor): Corresponding gt bboxes, shape (n, 4).
  110. eps (float): Eps to avoid log(0).
  111. Return:
  112. Tensor: Loss tensor.
  113. """
  114. # overlap
  115. lt = torch.max(pred[:, :2], target[:, :2])
  116. rb = torch.min(pred[:, 2:], target[:, 2:])
  117. wh = (rb - lt).clamp(min=0)
  118. overlap = wh[:, 0] * wh[:, 1]
  119. # union
  120. ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
  121. ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
  122. union = ap + ag - overlap + eps
  123. # IoU
  124. ious = overlap / union
  125. # enclose area
  126. enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
  127. enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
  128. enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
  129. cw = enclose_wh[:, 0]
  130. ch = enclose_wh[:, 1]
  131. c2 = cw**2 + ch**2 + eps
  132. b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
  133. b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
  134. b2_x1, b2_y1 = target[:, 0], target[:, 1]
  135. b2_x2, b2_y2 = target[:, 2], target[:, 3]
  136. left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
  137. right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
  138. rho2 = left + right
  139. # DIoU
  140. dious = ious - rho2 / c2
  141. loss = 1 - dious
  142. return loss
  143. @mmcv.jit(derivate=True, coderize=True)
  144. @weighted_loss
  145. def ciou_loss(pred, target, eps=1e-7):
  146. r"""`Implementation of paper `Enhancing Geometric Factors into
  147. Model Learning and Inference for Object Detection and Instance
  148. Segmentation <https://arxiv.org/abs/2005.03572>`_.
  149. Code is modified from https://github.com/Zzh-tju/CIoU.
  150. Args:
  151. pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
  152. shape (n, 4).
  153. target (Tensor): Corresponding gt bboxes, shape (n, 4).
  154. eps (float): Eps to avoid log(0).
  155. Return:
  156. Tensor: Loss tensor.
  157. """
  158. # overlap
  159. lt = torch.max(pred[:, :2], target[:, :2])
  160. rb = torch.min(pred[:, 2:], target[:, 2:])
  161. wh = (rb - lt).clamp(min=0)
  162. overlap = wh[:, 0] * wh[:, 1]
  163. # union
  164. ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
  165. ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
  166. union = ap + ag - overlap + eps
  167. # IoU
  168. ious = overlap / union
  169. # enclose area
  170. enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
  171. enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
  172. enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
  173. cw = enclose_wh[:, 0]
  174. ch = enclose_wh[:, 1]
  175. c2 = cw**2 + ch**2 + eps
  176. b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
  177. b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
  178. b2_x1, b2_y1 = target[:, 0], target[:, 1]
  179. b2_x2, b2_y2 = target[:, 2], target[:, 3]
  180. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  181. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  182. left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
  183. right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
  184. rho2 = left + right
  185. factor = 4 / math.pi**2
  186. v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
  187. with torch.no_grad():
  188. alpha = (ious > 0.5).float() * v / (1 - ious + v)
  189. # CIoU
  190. cious = ious - (rho2 / c2 + alpha * v)
  191. loss = 1 - cious.clamp(min=-1.0, max=1.0)
  192. return loss
  193. @LOSSES.register_module()
  194. class IoULoss(nn.Module):
  195. """IoULoss.
  196. Computing the IoU loss between a set of predicted bboxes and target bboxes.
  197. Args:
  198. linear (bool): If True, use linear scale of loss else determined
  199. by mode. Default: False.
  200. eps (float): Eps to avoid log(0).
  201. reduction (str): Options are "none", "mean" and "sum".
  202. loss_weight (float): Weight of loss.
  203. mode (str): Loss scaling mode, including "linear", "square", and "log".
  204. Default: 'log'
  205. """
  206. def __init__(self,
  207. linear=False,
  208. eps=1e-6,
  209. reduction='mean',
  210. loss_weight=1.0,
  211. mode='log'):
  212. super(IoULoss, self).__init__()
  213. assert mode in ['linear', 'square', 'log']
  214. if linear:
  215. mode = 'linear'
  216. warnings.warn('DeprecationWarning: Setting "linear=True" in '
  217. 'IOULoss is deprecated, please use "mode=`linear`" '
  218. 'instead.')
  219. self.mode = mode
  220. self.linear = linear
  221. self.eps = eps
  222. self.reduction = reduction
  223. self.loss_weight = loss_weight
  224. def forward(self,
  225. pred,
  226. target,
  227. weight=None,
  228. avg_factor=None,
  229. reduction_override=None,
  230. **kwargs):
  231. """Forward function.
  232. Args:
  233. pred (torch.Tensor): The prediction.
  234. target (torch.Tensor): The learning target of the prediction.
  235. weight (torch.Tensor, optional): The weight of loss for each
  236. prediction. Defaults to None.
  237. avg_factor (int, optional): Average factor that is used to average
  238. the loss. Defaults to None.
  239. reduction_override (str, optional): The reduction method used to
  240. override the original reduction method of the loss.
  241. Defaults to None. Options are "none", "mean" and "sum".
  242. """
  243. assert reduction_override in (None, 'none', 'mean', 'sum')
  244. reduction = (
  245. reduction_override if reduction_override else self.reduction)
  246. if (weight is not None) and (not torch.any(weight > 0)) and (
  247. reduction != 'none'):
  248. if pred.dim() == weight.dim() + 1:
  249. weight = weight.unsqueeze(1)
  250. return (pred * weight).sum() # 0
  251. if weight is not None and weight.dim() > 1:
  252. # TODO: remove this in the future
  253. # reduce the weight of shape (n, 4) to (n,) to match the
  254. # iou_loss of shape (n,)
  255. assert weight.shape == pred.shape
  256. weight = weight.mean(-1)
  257. loss = self.loss_weight * iou_loss(
  258. pred,
  259. target,
  260. weight,
  261. mode=self.mode,
  262. eps=self.eps,
  263. reduction=reduction,
  264. avg_factor=avg_factor,
  265. **kwargs)
  266. return loss
  267. @LOSSES.register_module()
  268. class BoundedIoULoss(nn.Module):
  269. def __init__(self, beta=0.2, eps=1e-3, reduction='mean', loss_weight=1.0):
  270. super(BoundedIoULoss, self).__init__()
  271. self.beta = beta
  272. self.eps = eps
  273. self.reduction = reduction
  274. self.loss_weight = loss_weight
  275. def forward(self,
  276. pred,
  277. target,
  278. weight=None,
  279. avg_factor=None,
  280. reduction_override=None,
  281. **kwargs):
  282. if weight is not None and not torch.any(weight > 0):
  283. if pred.dim() == weight.dim() + 1:
  284. weight = weight.unsqueeze(1)
  285. return (pred * weight).sum() # 0
  286. assert reduction_override in (None, 'none', 'mean', 'sum')
  287. reduction = (
  288. reduction_override if reduction_override else self.reduction)
  289. loss = self.loss_weight * bounded_iou_loss(
  290. pred,
  291. target,
  292. weight,
  293. beta=self.beta,
  294. eps=self.eps,
  295. reduction=reduction,
  296. avg_factor=avg_factor,
  297. **kwargs)
  298. return loss
  299. @LOSSES.register_module()
  300. class GIoULoss(nn.Module):
  301. def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
  302. super(GIoULoss, self).__init__()
  303. self.eps = eps
  304. self.reduction = reduction
  305. self.loss_weight = loss_weight
  306. def forward(self,
  307. pred,
  308. target,
  309. weight=None,
  310. avg_factor=None,
  311. reduction_override=None,
  312. **kwargs):
  313. if weight is not None and not torch.any(weight > 0):
  314. if pred.dim() == weight.dim() + 1:
  315. weight = weight.unsqueeze(1)
  316. return (pred * weight).sum() # 0
  317. assert reduction_override in (None, 'none', 'mean', 'sum')
  318. reduction = (
  319. reduction_override if reduction_override else self.reduction)
  320. if weight is not None and weight.dim() > 1:
  321. # TODO: remove this in the future
  322. # reduce the weight of shape (n, 4) to (n,) to match the
  323. # giou_loss of shape (n,)
  324. assert weight.shape == pred.shape
  325. weight = weight.mean(-1)
  326. loss = self.loss_weight * giou_loss(
  327. pred,
  328. target,
  329. weight,
  330. eps=self.eps,
  331. reduction=reduction,
  332. avg_factor=avg_factor,
  333. **kwargs)
  334. return loss
  335. @LOSSES.register_module()
  336. class DIoULoss(nn.Module):
  337. def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
  338. super(DIoULoss, self).__init__()
  339. self.eps = eps
  340. self.reduction = reduction
  341. self.loss_weight = loss_weight
  342. def forward(self,
  343. pred,
  344. target,
  345. weight=None,
  346. avg_factor=None,
  347. reduction_override=None,
  348. **kwargs):
  349. if weight is not None and not torch.any(weight > 0):
  350. if pred.dim() == weight.dim() + 1:
  351. weight = weight.unsqueeze(1)
  352. return (pred * weight).sum() # 0
  353. assert reduction_override in (None, 'none', 'mean', 'sum')
  354. reduction = (
  355. reduction_override if reduction_override else self.reduction)
  356. if weight is not None and weight.dim() > 1:
  357. # TODO: remove this in the future
  358. # reduce the weight of shape (n, 4) to (n,) to match the
  359. # giou_loss of shape (n,)
  360. assert weight.shape == pred.shape
  361. weight = weight.mean(-1)
  362. loss = self.loss_weight * diou_loss(
  363. pred,
  364. target,
  365. weight,
  366. eps=self.eps,
  367. reduction=reduction,
  368. avg_factor=avg_factor,
  369. **kwargs)
  370. return loss
  371. @LOSSES.register_module()
  372. class CIoULoss(nn.Module):
  373. def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
  374. super(CIoULoss, self).__init__()
  375. self.eps = eps
  376. self.reduction = reduction
  377. self.loss_weight = loss_weight
  378. def forward(self,
  379. pred,
  380. target,
  381. weight=None,
  382. avg_factor=None,
  383. reduction_override=None,
  384. **kwargs):
  385. if weight is not None and not torch.any(weight > 0):
  386. if pred.dim() == weight.dim() + 1:
  387. weight = weight.unsqueeze(1)
  388. return (pred * weight).sum() # 0
  389. assert reduction_override in (None, 'none', 'mean', 'sum')
  390. reduction = (
  391. reduction_override if reduction_override else self.reduction)
  392. if weight is not None and weight.dim() > 1:
  393. # TODO: remove this in the future
  394. # reduce the weight of shape (n, 4) to (n,) to match the
  395. # giou_loss of shape (n,)
  396. assert weight.shape == pred.shape
  397. weight = weight.mean(-1)
  398. loss = self.loss_weight * ciou_loss(
  399. pred,
  400. target,
  401. weight,
  402. eps=self.eps,
  403. reduction=reduction,
  404. avg_factor=avg_factor,
  405. **kwargs)
  406. return loss

No Description

Contributors (2)