You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 3.1 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import functools
  3. import mmcv
  4. import torch.nn.functional as F
  5. def reduce_loss(loss, reduction):
  6. """Reduce loss as specified.
  7. Args:
  8. loss (Tensor): Elementwise loss tensor.
  9. reduction (str): Options are "none", "mean" and "sum".
  10. Return:
  11. Tensor: Reduced loss tensor.
  12. """
  13. reduction_enum = F._Reduction.get_enum(reduction)
  14. # none: 0, elementwise_mean:1, sum: 2
  15. if reduction_enum == 0:
  16. return loss
  17. elif reduction_enum == 1:
  18. return loss.mean()
  19. elif reduction_enum == 2:
  20. return loss.sum()
  21. @mmcv.jit(derivate=True, coderize=True)
  22. def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
  23. """Apply element-wise weight and reduce loss.
  24. Args:
  25. loss (Tensor): Element-wise loss.
  26. weight (Tensor): Element-wise weights.
  27. reduction (str): Same as built-in losses of PyTorch.
  28. avg_factor (float): Average factor when computing the mean of losses.
  29. Returns:
  30. Tensor: Processed loss values.
  31. """
  32. # if weight is specified, apply element-wise weight
  33. if weight is not None:
  34. loss = loss * weight
  35. # if avg_factor is not specified, just reduce the loss
  36. if avg_factor is None:
  37. loss = reduce_loss(loss, reduction)
  38. else:
  39. # if reduction is mean, then average the loss by avg_factor
  40. if reduction == 'mean':
  41. loss = loss.sum() / avg_factor
  42. # if reduction is 'none', then do nothing, otherwise raise an error
  43. elif reduction != 'none':
  44. raise ValueError('avg_factor can not be used with reduction="sum"')
  45. return loss
  46. def weighted_loss(loss_func):
  47. """Create a weighted version of a given loss function.
  48. To use this decorator, the loss function must have the signature like
  49. `loss_func(pred, target, **kwargs)`. The function only needs to compute
  50. element-wise loss without any reduction. This decorator will add weight
  51. and reduction arguments to the function. The decorated function will have
  52. the signature like `loss_func(pred, target, weight=None, reduction='mean',
  53. avg_factor=None, **kwargs)`.
  54. :Example:
  55. >>> import torch
  56. >>> @weighted_loss
  57. >>> def l1_loss(pred, target):
  58. >>> return (pred - target).abs()
  59. >>> pred = torch.Tensor([0, 2, 3])
  60. >>> target = torch.Tensor([1, 1, 1])
  61. >>> weight = torch.Tensor([1, 0, 1])
  62. >>> l1_loss(pred, target)
  63. tensor(1.3333)
  64. >>> l1_loss(pred, target, weight)
  65. tensor(1.)
  66. >>> l1_loss(pred, target, reduction='none')
  67. tensor([1., 1., 2.])
  68. >>> l1_loss(pred, target, weight, avg_factor=2)
  69. tensor(1.5000)
  70. """
  71. @functools.wraps(loss_func)
  72. def wrapper(pred,
  73. target,
  74. weight=None,
  75. reduction='mean',
  76. avg_factor=None,
  77. **kwargs):
  78. # get element-wise loss
  79. loss = loss_func(pred, target, **kwargs)
  80. loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
  81. return loss
  82. return wrapper

No Description

Contributors (3)