You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dice_loss.py 4.3 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch
  3. import torch.nn as nn
  4. from ..builder import LOSSES
  5. from .utils import weight_reduce_loss
  def dice_loss(pred,
                target,
                weight=None,
                eps=1e-3,
                reduction='mean',
                avg_factor=None):
      """Compute the dice loss between a prediction and its target.

      The loss follows `V-Net: Fully Convolutional Neural Networks for
      Volumetric Medical Image Segmentation
      <https://arxiv.org/abs/1606.04797>`_. Each sample is flattened to a
      vector and its loss is ``1 - 2 * sum(p * t) / (sum(p * p) + sum(t * t))``,
      where ``eps`` is added to both squared-sum terms.

      Args:
          pred (torch.Tensor): The prediction, has a shape (n, *).
          target (torch.Tensor): The learning label of the prediction,
              shape (n, *), same shape of pred. It is cast to float.
          weight (torch.Tensor, optional): The weight of loss for each
              prediction, has a shape (n,). Defaults to None.
          eps (float): Added to each squared-sum term to avoid dividing
              by zero. Default: 1e-3.
          reduction (str, optional): The method used to reduce the loss into
              a scalar. Options are "none", "mean" and "sum".
              Defaults to 'mean'.
          avg_factor (int, optional): Average factor that is used to average
              the loss. Defaults to None.

      Returns:
          torch.Tensor: The result of ``weight_reduce_loss`` applied to the
          per-sample dice loss.
      """
      # Collapse every non-batch dimension so each sample is one vector.
      flat_pred = pred.flatten(1)
      flat_target = target.flatten(1).float()

      overlap = (flat_pred * flat_target).sum(dim=1)
      pred_sq_sum = (flat_pred * flat_pred).sum(dim=1) + eps
      target_sq_sum = (flat_target * flat_target).sum(dim=1) + eps
      per_sample_loss = 1 - (2 * overlap) / (pred_sq_sum + target_sq_sum)

      if weight is not None:
          # The weight must line up one-to-one with the per-sample losses.
          assert weight.ndim == per_sample_loss.ndim
          assert len(weight) == len(pred)

      return weight_reduce_loss(per_sample_loss, weight, reduction,
                                avg_factor)
  @LOSSES.register_module()
  class DiceLoss(nn.Module):

      def __init__(self,
                   use_sigmoid=True,
                   activate=True,
                   reduction='mean',
                   loss_weight=1.0,
                   eps=1e-3):
          """Dice Loss, which is proposed in
          `V-Net: Fully Convolutional Neural Networks for Volumetric
          Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

          Args:
              use_sigmoid (bool, optional): Whether sigmoid is the activation
                  applied to the prediction. Only the sigmoid path is
                  implemented; with ``activate=True`` and
                  ``use_sigmoid=False`` the forward pass raises
                  ``NotImplementedError``. Defaults to True.
              activate (bool): Whether to apply the activation to the
                  prediction inside this module. When False the prediction
                  is passed to the loss unchanged. Defaults to True.
              reduction (str, optional): The method used
                  to reduce the loss. Options are "none",
                  "mean" and "sum". Defaults to 'mean'.
              loss_weight (float, optional): Weight of loss. Defaults to 1.0.
              eps (float): Avoid dividing by zero. Defaults to 1e-3.
          """
          super().__init__()
          self.use_sigmoid = use_sigmoid
          self.activate = activate
          self.reduction = reduction
          self.loss_weight = loss_weight
          self.eps = eps

      def forward(self,
                  pred,
                  target,
                  weight=None,
                  reduction_override=None,
                  avg_factor=None):
          """Forward function.

          Args:
              pred (torch.Tensor): The prediction, has a shape (n, *).
              target (torch.Tensor): The label of the prediction,
                  shape (n, *), same shape of pred.
              weight (torch.Tensor, optional): The weight of loss for each
                  prediction, has a shape (n,). Defaults to None.
              reduction_override (str, optional): The reduction method used to
                  override the original reduction method of the loss.
                  Options are "none", "mean" and "sum".
              avg_factor (int, optional): Average factor that is used to average
                  the loss. Defaults to None.

          Returns:
              torch.Tensor: The calculated loss, scaled by ``loss_weight``.

          Raises:
              NotImplementedError: If ``activate`` is True while
                  ``use_sigmoid`` is False.
          """
          assert reduction_override in (None, 'none', 'mean', 'sum')
          # A falsy override (None) falls back to the configured reduction.
          reduction = reduction_override or self.reduction

          if self.activate:
              if not self.use_sigmoid:
                  raise NotImplementedError
              pred = pred.sigmoid()

          raw_loss = dice_loss(
              pred,
              target,
              weight,
              eps=self.eps,
              reduction=reduction,
              avg_factor=avg_factor)
          return self.loss_weight * raw_loss

No Description

Contributors (3)