# Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn as nn from ..builder import LOSSES from .utils import weight_reduce_loss def dice_loss(pred, target, weight=None, eps=1e-3, reduction='mean', avg_factor=None): """Calculate dice loss, which is proposed in `V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation `_. Args: pred (torch.Tensor): The prediction, has a shape (n, *) target (torch.Tensor): The learning label of the prediction, shape (n, *), same shape of pred. weight (torch.Tensor, optional): The weight of loss for each prediction, has a shape (n,). Defaults to None. eps (float): Avoid dividing by zero. Default: 1e-3. reduction (str, optional): The method used to reduce the loss into a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. """ input = pred.flatten(1) target = target.flatten(1).float() a = torch.sum(input * target, 1) b = torch.sum(input * input, 1) + eps c = torch.sum(target * target, 1) + eps d = (2 * a) / (b + c) loss = 1 - d if weight is not None: assert weight.ndim == loss.ndim assert len(weight) == len(pred) loss = weight_reduce_loss(loss, weight, reduction, avg_factor) return loss @LOSSES.register_module() class DiceLoss(nn.Module): def __init__(self, use_sigmoid=True, activate=True, reduction='mean', loss_weight=1.0, eps=1e-3): """`Dice Loss, which is proposed in `V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation `_. Args: use_sigmoid (bool, optional): Whether to the prediction is used for sigmoid or softmax. Defaults to True. activate (bool): Whether to activate the predictions inside, this will disable the inside sigmoid operation. Defaults to True. reduction (str, optional): The method used to reduce the loss. Options are "none", "mean" and "sum". Defaults to 'mean'. loss_weight (float, optional): Weight of loss. Defaults to 1.0. eps (float): Avoid dividing by zero. Defaults to 1e-3. """ super(DiceLoss, self).__init__() self.use_sigmoid = use_sigmoid self.reduction = reduction self.loss_weight = loss_weight self.eps = eps self.activate = activate def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None): """Forward function. Args: pred (torch.Tensor): The prediction, has a shape (n, *). target (torch.Tensor): The label of the prediction, shape (n, *), same shape of pred. weight (torch.Tensor, optional): The weight of loss for each prediction, has a shape (n,). Defaults to None. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Options are "none", "mean" and "sum". Returns: torch.Tensor: The calculated loss """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) if self.activate: if self.use_sigmoid: pred = pred.sigmoid() else: raise NotImplementedError loss = self.loss_weight * dice_loss( pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor) return loss