
seesaw_loss.py 10 kB

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .accuracy import accuracy
from .cross_entropy_loss import cross_entropy
from .utils import weight_reduce_loss


def seesaw_ce_loss(cls_score,
                   labels,
                   label_weights,
                   cum_samples,
                   num_classes,
                   p,
                   q,
                   eps,
                   reduction='mean',
                   avg_factor=None):
    """Calculate the Seesaw CrossEntropy loss.

    Args:
        cls_score (torch.Tensor): The prediction with shape (N, C),
            C is the number of classes.
        labels (torch.Tensor): The learning label of the prediction.
        label_weights (torch.Tensor): Sample-wise loss weight.
        cum_samples (torch.Tensor): Cumulative samples for each category.
        num_classes (int): The number of classes.
        p (float): The ``p`` in the mitigation factor.
        q (float): The ``q`` in the compensation factor.
        eps (float): The minimal value of divisor to smooth
            the computation of the compensation factor.
        reduction (str, optional): The method used to reduce the loss.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    assert cls_score.size(-1) == num_classes
    assert len(cum_samples) == num_classes

    onehot_labels = F.one_hot(labels, num_classes)
    seesaw_weights = cls_score.new_ones(onehot_labels.size())

    # mitigation factor: down-weight the penalty on negative classes that
    # are rarer than the ground-truth class
    if p > 0:
        sample_ratio_matrix = cum_samples[None, :].clamp(
            min=1) / cum_samples[:, None].clamp(min=1)
        index = (sample_ratio_matrix < 1.0).float()
        sample_weights = sample_ratio_matrix.pow(p) * index + (1 - index)
        mitigation_factor = sample_weights[labels.long(), :]
        seesaw_weights = seesaw_weights * mitigation_factor

    # compensation factor: boost the penalty on classes that score higher
    # than the ground-truth class
    if q > 0:
        scores = F.softmax(cls_score.detach(), dim=1)
        self_scores = scores[
            torch.arange(0, len(scores)).to(scores.device).long(),
            labels.long()]
        score_matrix = scores / self_scores[:, None].clamp(min=eps)
        index = (score_matrix > 1.0).float()
        compensation_factor = score_matrix.pow(q) * index + (1 - index)
        seesaw_weights = seesaw_weights * compensation_factor

    # fold the seesaw weights into the logits of the negative classes
    cls_score = cls_score + (seesaw_weights.log() * (1 - onehot_labels))

    loss = F.cross_entropy(cls_score, labels, weight=None, reduction='none')

    if label_weights is not None:
        label_weights = label_weights.float()
    loss = weight_reduce_loss(
        loss, weight=label_weights, reduction=reduction, avg_factor=avg_factor)
    return loss
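
# To make the reweighting concrete, a minimal doctest-style sketch of the
# mitigation factor with hypothetical counts and p = 0.8 (the counts and
# values are illustrative, not from the paper):
#
#   >>> cum = torch.tensor([1000., 100., 10.])  # class 0 frequent, 2 rare
#   >>> ratio = cum[None, :].clamp(min=1) / cum[:, None].clamp(min=1)
#   >>> index = (ratio < 1.0).float()
#   >>> (ratio.pow(0.8) * index + (1 - index))[0]
#   tensor([1.0000, 0.1585, 0.0251])
#
# For a sample labeled with the frequent class 0, the rarer classes 1 and 2
# receive seesaw weights below 1, so their logits are reduced and they are
# punished less as negatives; rows for rare labels remain all ones.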
@LOSSES.register_module()
class SeesawLoss(nn.Module):
    """
    Seesaw Loss for Long-Tailed Instance Segmentation (CVPR 2021)
    arXiv: https://arxiv.org/abs/2008.10032

    Args:
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Only False is supported.
        p (float, optional): The ``p`` in the mitigation factor.
            Defaults to 0.8.
        q (float, optional): The ``q`` in the compensation factor.
            Defaults to 2.0.
        num_classes (int, optional): The number of classes.
            Defaults to 1203 for the LVIS v1 dataset.
        eps (float, optional): The minimal value of divisor to smooth
            the computation of the compensation factor.
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of the loss. Defaults to 1.0.
        return_dict (bool, optional): Whether to return the losses as a dict.
            Defaults to True.
    """

    def __init__(self,
                 use_sigmoid=False,
                 p=0.8,
                 q=2.0,
                 num_classes=1203,
                 eps=1e-2,
                 reduction='mean',
                 loss_weight=1.0,
                 return_dict=True):
        super(SeesawLoss, self).__init__()
        assert not use_sigmoid
        self.use_sigmoid = False
        self.p = p
        self.q = q
        self.num_classes = num_classes
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.return_dict = return_dict

        self.cls_criterion = seesaw_ce_loss

        # cumulative samples for each category (num_classes foreground
        # classes plus one background slot), updated online in forward()
        self.register_buffer(
            'cum_samples',
            torch.zeros(self.num_classes + 1, dtype=torch.float))

        # custom output channels of the classifier
        self.custom_cls_channels = True
        # custom activation of cls_score
        self.custom_activation = True
        # custom accuracy of the classifier
        self.custom_accuracy = True
    def _split_cls_score(self, cls_score):
        # split cls_score into cls_score_classes and cls_score_objectness
        assert cls_score.size(-1) == self.num_classes + 2
        cls_score_classes = cls_score[..., :-2]
        cls_score_objectness = cls_score[..., -2:]
        return cls_score_classes, cls_score_objectness

    def get_cls_channels(self, num_classes):
        """Get custom classification channels.

        Args:
            num_classes (int): The number of classes.

        Returns:
            int: The custom classification channels.
        """
        assert num_classes == self.num_classes
        return num_classes + 2

    def get_activation(self, cls_score):
        """Get custom activation of cls_score.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C + 2).

        Returns:
            torch.Tensor: The custom activation of cls_score with shape
                (N, C + 1).
        """
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        score_classes = F.softmax(cls_score_classes, dim=-1)
        score_objectness = F.softmax(cls_score_objectness, dim=-1)
        score_pos = score_objectness[..., [0]]
        score_neg = score_objectness[..., [1]]
        # scale class probabilities by the positive (foreground) probability
        score_classes = score_classes * score_pos
        scores = torch.cat([score_classes, score_neg], dim=-1)
        return scores
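
    # Channel-layout sketch (hypothetical num_classes = 3): the classifier
    # emits 5 channels, split into 3 class logits and a 2-channel pos/neg
    # objectness head. ``get_activation`` returns 4 scores that still sum
    # to 1, since sum_c P(c) * P(pos) + P(neg) = P(pos) + P(neg) = 1:
    #
    #   >>> loss = SeesawLoss(num_classes=3)
    #   >>> scores = loss.get_activation(torch.randn(4, 5))
    #   >>> bool(torch.allclose(scores.sum(-1), torch.ones(4)))
    #   True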
    def get_accuracy(self, cls_score, labels):
        """Get custom accuracy w.r.t. cls_score and labels.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C + 2).
            labels (torch.Tensor): The learning label of the prediction.

        Returns:
            Dict[str, torch.Tensor]: The accuracy for objectness and classes,
                respectively.
        """
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()
        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        acc_objectness = accuracy(cls_score_objectness, obj_labels)
        acc_classes = accuracy(cls_score_classes[pos_inds], labels[pos_inds])
        acc = dict()
        acc['acc_objectness'] = acc_objectness
        acc['acc_classes'] = acc_classes
        return acc
    def forward(self,
                cls_score,
                labels,
                label_weights=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            cls_score (torch.Tensor): The prediction with shape (N, C + 2).
            labels (torch.Tensor): The learning label of the prediction.
            label_weights (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The method used to reduce the
                loss. Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor | Dict[str, torch.Tensor]:
                If ``return_dict == False``, the calculated loss;
                if ``return_dict == True``, a dict of the calculated losses
                for objectness and classes, respectively.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        assert cls_score.size(-1) == self.num_classes + 2
        pos_inds = labels < self.num_classes
        # 0 for pos, 1 for neg
        obj_labels = (labels == self.num_classes).long()

        # accumulate the samples for each category
        unique_labels = labels.unique()
        for u_l in unique_labels:
            inds_ = labels == u_l.item()
            self.cum_samples[u_l] += inds_.sum()

        if label_weights is not None:
            label_weights = label_weights.float()
        else:
            label_weights = labels.new_ones(labels.size(), dtype=torch.float)

        cls_score_classes, cls_score_objectness = self._split_cls_score(
            cls_score)
        # calculate loss_cls_classes (only need pos samples)
        if pos_inds.sum() > 0:
            loss_cls_classes = self.loss_weight * self.cls_criterion(
                cls_score_classes[pos_inds], labels[pos_inds],
                label_weights[pos_inds], self.cum_samples[:self.num_classes],
                self.num_classes, self.p, self.q, self.eps, reduction,
                avg_factor)
        else:
            # keep a zero-valued, graph-connected loss when there are no
            # positive samples
            loss_cls_classes = cls_score_classes[pos_inds].sum()
        # calculate loss_cls_objectness
        loss_cls_objectness = self.loss_weight * cross_entropy(
            cls_score_objectness, obj_labels, label_weights, reduction,
            avg_factor)

        if self.return_dict:
            loss_cls = dict()
            loss_cls['loss_cls_objectness'] = loss_cls_objectness
            loss_cls['loss_cls_classes'] = loss_cls_classes
        else:
            loss_cls = loss_cls_classes + loss_cls_objectness
        return loss_cls
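
A minimal end-to-end sketch of how the loss might be used, assuming mmdet is installed so that SeesawLoss is importable from mmdet.models.losses; the batch size, logits, and labels below are made up for illustration:

import torch
from mmdet.models.losses import SeesawLoss

loss_cls = SeesawLoss(p=0.8, q=2.0, num_classes=1203, return_dict=True)

# 8 RoIs; the classifier must emit num_classes + 2 channels
# (1203 class logits plus the 2-channel pos/neg objectness head)
cls_score = torch.randn(8, loss_cls.get_cls_channels(1203))
# labels in [0, num_classes]; the value 1203 marks background RoIs
labels = torch.randint(0, 1204, (8,))

losses = loss_cls(cls_score, labels)
print(losses['loss_cls_objectness'], losses['loss_cls_classes'])

With return_dict=False the two terms are summed into a single scalar instead of being reported separately.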
