You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mse_loss.py 1.9 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from ..builder import LOSSES
  5. from .utils import weighted_loss
  6. @weighted_loss
  7. def mse_loss(pred, target):
  8. """Warpper of mse loss."""
  9. return F.mse_loss(pred, target, reduction='none')
  10. @LOSSES.register_module()
  11. class MSELoss(nn.Module):
  12. """MSELoss.
  13. Args:
  14. reduction (str, optional): The method that reduces the loss to a
  15. scalar. Options are "none", "mean" and "sum".
  16. loss_weight (float, optional): The weight of the loss. Defaults to 1.0
  17. """
  18. def __init__(self, reduction='mean', loss_weight=1.0):
  19. super().__init__()
  20. self.reduction = reduction
  21. self.loss_weight = loss_weight
  22. def forward(self,
  23. pred,
  24. target,
  25. weight=None,
  26. avg_factor=None,
  27. reduction_override=None):
  28. """Forward function of loss.
  29. Args:
  30. pred (torch.Tensor): The prediction.
  31. target (torch.Tensor): The learning target of the prediction.
  32. weight (torch.Tensor, optional): Weight of the loss for each
  33. prediction. Defaults to None.
  34. avg_factor (int, optional): Average factor that is used to average
  35. the loss. Defaults to None.
  36. reduction_override (str, optional): The reduction method used to
  37. override the original reduction method of the loss.
  38. Defaults to None.
  39. Returns:
  40. torch.Tensor: The calculated loss
  41. """
  42. assert reduction_override in (None, 'none', 'mean', 'sum')
  43. reduction = (
  44. reduction_override if reduction_override else self.reduction)
  45. loss = self.loss_weight * mse_loss(
  46. pred, target, weight, reduction=reduction, avg_factor=avg_factor)
  47. return loss

No Description

Contributors (3)