You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

cross_entropy_loss.py 9.7 kB

2 years ago

  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from ..builder import LOSSES
  6. from .utils import weight_reduce_loss
  7. def cross_entropy(pred,
  8. label,
  9. weight=None,
  10. reduction='mean',
  11. avg_factor=None,
  12. class_weight=None,
  13. ignore_index=-100):
  14. """Calculate the CrossEntropy loss.
  15. Args:
  16. pred (torch.Tensor): The prediction with shape (N, C), C is the number
  17. of classes.
  18. label (torch.Tensor): The learning label of the prediction.
  19. weight (torch.Tensor, optional): Sample-wise loss weight.
  20. reduction (str, optional): The method used to reduce the loss.
  21. avg_factor (int, optional): Average factor that is used to average
  22. the loss. Defaults to None.
  23. class_weight (list[float], optional): The weight for each class.
  24. ignore_index (int | None): The label index to be ignored.
  25. If None, it will be set to default value. Default: -100.
  26. Returns:
  27. torch.Tensor: The calculated loss
  28. """
  29. # The default value of ignore_index is the same as F.cross_entropy
  30. ignore_index = -100 if ignore_index is None else ignore_index
  31. # element-wise losses
  32. loss = F.cross_entropy(
  33. pred,
  34. label,
  35. weight=class_weight,
  36. reduction='none',
  37. ignore_index=ignore_index)
  38. # apply weights and do the reduction
  39. if weight is not None:
  40. weight = weight.float()
  41. loss = weight_reduce_loss(
  42. loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
  43. return loss
  44. def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):
  45. """Expand onehot labels to match the size of prediction."""
  46. bin_labels = labels.new_full((labels.size(0), label_channels), 0)
  47. valid_mask = (labels >= 0) & (labels != ignore_index)
  48. inds = torch.nonzero(
  49. valid_mask & (labels < label_channels), as_tuple=False)
  50. if inds.numel() > 0:
  51. bin_labels[inds, labels[inds]] = 1
  52. valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),
  53. label_channels).float()
  54. if label_weights is None:
  55. bin_label_weights = valid_mask
  56. else:
  57. bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)
  58. bin_label_weights *= valid_mask
  59. return bin_labels, bin_label_weights
  60. def binary_cross_entropy(pred,
  61. label,
  62. weight=None,
  63. reduction='mean',
  64. avg_factor=None,
  65. class_weight=None,
  66. ignore_index=-100):
  67. """Calculate the binary CrossEntropy loss.
  68. Args:
  69. pred (torch.Tensor): The prediction with shape (N, 1).
  70. label (torch.Tensor): The learning label of the prediction.
  71. weight (torch.Tensor, optional): Sample-wise loss weight.
  72. reduction (str, optional): The method used to reduce the loss.
  73. Options are "none", "mean" and "sum".
  74. avg_factor (int, optional): Average factor that is used to average
  75. the loss. Defaults to None.
  76. class_weight (list[float], optional): The weight for each class.
  77. ignore_index (int | None): The label index to be ignored.
  78. If None, it will be set to default value. Default: -100.
  79. Returns:
  80. torch.Tensor: The calculated loss.
  81. """
  82. # The default value of ignore_index is the same as F.cross_entropy
  83. ignore_index = -100 if ignore_index is None else ignore_index
  84. if pred.dim() != label.dim():
  85. label, weight = _expand_onehot_labels(label, weight, pred.size(-1),
  86. ignore_index)
  87. # weighted element-wise losses
  88. if weight is not None:
  89. weight = weight.float()
  90. loss = F.binary_cross_entropy_with_logits(
  91. pred, label.float(), pos_weight=class_weight, reduction='none')
  92. # do the reduction for the weighted loss
  93. loss = weight_reduce_loss(
  94. loss, weight, reduction=reduction, avg_factor=avg_factor)
  95. return loss
  96. def mask_cross_entropy(pred,
  97. target,
  98. label,
  99. reduction='mean',
  100. avg_factor=None,
  101. class_weight=None,
  102. ignore_index=None):
  103. """Calculate the CrossEntropy loss for masks.
  104. Args:
  105. pred (torch.Tensor): The prediction with shape (N, C, *), C is the
  106. number of classes. The trailing * indicates arbitrary shape.
  107. target (torch.Tensor): The learning label of the prediction.
  108. label (torch.Tensor): ``label`` indicates the class label of the mask
  109. corresponding object. This will be used to select the mask in the
  110. of the class which the object belongs to when the mask prediction
  111. if not class-agnostic.
  112. reduction (str, optional): The method used to reduce the loss.
  113. Options are "none", "mean" and "sum".
  114. avg_factor (int, optional): Average factor that is used to average
  115. the loss. Defaults to None.
  116. class_weight (list[float], optional): The weight for each class.
  117. ignore_index (None): Placeholder, to be consistent with other loss.
  118. Default: None.
  119. Returns:
  120. torch.Tensor: The calculated loss
  121. Example:
  122. >>> N, C = 3, 11
  123. >>> H, W = 2, 2
  124. >>> pred = torch.randn(N, C, H, W) * 1000
  125. >>> target = torch.rand(N, H, W)
  126. >>> label = torch.randint(0, C, size=(N,))
  127. >>> reduction = 'mean'
  128. >>> avg_factor = None
  129. >>> class_weights = None
  130. >>> loss = mask_cross_entropy(pred, target, label, reduction,
  131. >>> avg_factor, class_weights)
  132. >>> assert loss.shape == (1,)
  133. """
  134. assert ignore_index is None, 'BCE loss does not support ignore_index'
  135. # TODO: handle these two reserved arguments
  136. assert reduction == 'mean' and avg_factor is None
  137. num_rois = pred.size()[0]
  138. inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
  139. pred_slice = pred[inds, label].squeeze(1)
  140. return F.binary_cross_entropy_with_logits(
  141. pred_slice, target, weight=class_weight, reduction='mean')[None]
  142. @LOSSES.register_module()
  143. class CrossEntropyLoss(nn.Module):
  144. def __init__(self,
  145. use_sigmoid=False,
  146. use_mask=False,
  147. reduction='mean',
  148. class_weight=None,
  149. ignore_index=None,
  150. loss_weight=1.0):
  151. """CrossEntropyLoss.
  152. Args:
  153. use_sigmoid (bool, optional): Whether the prediction uses sigmoid
  154. of softmax. Defaults to False.
  155. use_mask (bool, optional): Whether to use mask cross entropy loss.
  156. Defaults to False.
  157. reduction (str, optional): . Defaults to 'mean'.
  158. Options are "none", "mean" and "sum".
  159. class_weight (list[float], optional): Weight of each class.
  160. Defaults to None.
  161. ignore_index (int | None): The label index to be ignored.
  162. Defaults to None.
  163. loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
  164. """
  165. super(CrossEntropyLoss, self).__init__()
  166. assert (use_sigmoid is False) or (use_mask is False)
  167. self.use_sigmoid = use_sigmoid
  168. self.use_mask = use_mask
  169. self.reduction = reduction
  170. self.loss_weight = loss_weight
  171. self.class_weight = class_weight
  172. self.ignore_index = ignore_index
  173. if self.use_sigmoid:
  174. self.cls_criterion = binary_cross_entropy
  175. elif self.use_mask:
  176. self.cls_criterion = mask_cross_entropy
  177. else:
  178. self.cls_criterion = cross_entropy
  179. def forward(self,
  180. cls_score,
  181. label,
  182. weight=None,
  183. avg_factor=None,
  184. reduction_override=None,
  185. ignore_index=None,
  186. **kwargs):
  187. """Forward function.
  188. Args:
  189. cls_score (torch.Tensor): The prediction.
  190. label (torch.Tensor): The learning label of the prediction.
  191. weight (torch.Tensor, optional): Sample-wise loss weight.
  192. avg_factor (int, optional): Average factor that is used to average
  193. the loss. Defaults to None.
  194. reduction_override (str, optional): The method used to reduce the
  195. loss. Options are "none", "mean" and "sum".
  196. ignore_index (int | None): The label index to be ignored.
  197. If not None, it will override the default value. Default: None.
  198. Returns:
  199. torch.Tensor: The calculated loss.
  200. """
  201. assert reduction_override in (None, 'none', 'mean', 'sum')
  202. reduction = (
  203. reduction_override if reduction_override else self.reduction)
  204. if ignore_index is None:
  205. ignore_index = self.ignore_index
  206. if self.class_weight is not None:
  207. class_weight = cls_score.new_tensor(
  208. self.class_weight, device=cls_score.device)
  209. else:
  210. class_weight = None
  211. loss_cls = self.loss_weight * self.cls_criterion(
  212. cls_score,
  213. label,
  214. weight,
  215. class_weight=class_weight,
  216. reduction=reduction,
  217. avg_factor=avg_factor,
  218. ignore_index=ignore_index,
  219. **kwargs)
  220. return loss_cls

No Description

Contributors (2)