
test_loss.py

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.utils import digit_version

from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss, DiceLoss,
                                 DistributionFocalLoss, FocalLoss,
                                 GaussianFocalLoss,
                                 KnowledgeDistillationKLDivLoss, L1Loss,
                                 MSELoss, QualityFocalLoss, SeesawLoss,
                                 SmoothL1Loss, VarifocalLoss)
from mmdet.models.losses.ghm_loss import GHMC, GHMR
from mmdet.models.losses.iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss,
                                          GIoULoss, IoULoss)
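
# The tests below exercise the common interface shared by mmdet's loss
# modules: plain forward calls, optional element-wise weights, the
# reduction_override argument and avg_factor handling, plus a few
# loss-specific corner cases (ignore_index for cross-entropy, weight shape
# checks for DiceLoss).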


@pytest.mark.parametrize(
    'loss_class', [IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss])
def test_iou_type_loss_zeros_weight(loss_class):
    pred = torch.rand((10, 4))
    target = torch.rand((10, 4))
    weight = torch.zeros(10)

    loss = loss_class()(pred, target, weight)
    assert loss == 0.


@pytest.mark.parametrize('loss_class', [
    BalancedL1Loss, BoundedIoULoss, CIoULoss, CrossEntropyLoss, DIoULoss,
    FocalLoss, DistributionFocalLoss, MSELoss, SeesawLoss, GaussianFocalLoss,
    GIoULoss, IoULoss, L1Loss, QualityFocalLoss, VarifocalLoss, GHMR, GHMC,
    SmoothL1Loss, KnowledgeDistillationKLDivLoss, DiceLoss
])
def test_loss_with_reduction_override(loss_class):
    pred = torch.rand((10, 4))
    target = torch.rand((10, 4))
    weight = None

    with pytest.raises(AssertionError):
        # only reduction_override from [None, 'none', 'mean', 'sum'] is
        # allowed; anything else trips the assertion in the loss forward
        reduction_override = True
        loss_class()(
            pred, target, weight, reduction_override=reduction_override)


@pytest.mark.parametrize('loss_class', [
    IoULoss, BoundedIoULoss, GIoULoss, DIoULoss, CIoULoss, MSELoss, L1Loss,
    SmoothL1Loss, BalancedL1Loss
])
@pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
def test_regression_losses(loss_class, input_shape):
    pred = torch.rand(input_shape)
    target = torch.rand(input_shape)
    weight = torch.rand(input_shape)

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = loss_class()(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        loss = loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)


@pytest.mark.parametrize('loss_class', [FocalLoss, CrossEntropyLoss])
@pytest.mark.parametrize('input_shape', [(10, 5), (0, 5)])
def test_classification_losses(loss_class, input_shape):
    if input_shape[0] == 0 and digit_version(
            torch.__version__) < digit_version('1.5.0'):
        pytest.skip(
            f'CELoss in PyTorch {torch.__version__} does not support empty '
            'tensor.')
    pred = torch.rand(input_shape)
    target = torch.randint(0, 5, (input_shape[0], ))

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        loss = loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)
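
# GHMR is the regression variant of the gradient harmonizing mechanism (GHM)
# losses; its forward is called with an explicit element-wise weight, so it
# is exercised separately from the generic regression losses above.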


@pytest.mark.parametrize('loss_class', [GHMR])
@pytest.mark.parametrize('input_shape', [(10, 4), (0, 4)])
def test_GHMR_loss(loss_class, input_shape):
    pred = torch.rand(input_shape)
    target = torch.rand(input_shape)
    weight = torch.rand(input_shape)

    # Test loss forward
    loss = loss_class()(pred, target, weight)
    assert isinstance(loss, torch.Tensor)


@pytest.mark.parametrize('use_sigmoid', [True, False])
def test_loss_with_ignore_index(use_sigmoid):
    # Test cross_entropy loss
    loss_class = CrossEntropyLoss(
        use_sigmoid=use_sigmoid, use_mask=False, ignore_index=255)
    pred = torch.rand((10, 5))
    target = torch.randint(0, 5, (10, ))
    ignored_indices = torch.randint(0, 10, (2, ), dtype=torch.long)
    target[ignored_indices] = 255

    # Test loss forward with the default (constructor) ignore_index
    loss_with_ignore = loss_class(pred, target, reduction_override='sum')
    assert isinstance(loss_with_ignore, torch.Tensor)

    # Test loss forward with ignore_index overridden at call time
    target[ignored_indices] = 250
    loss_with_forward_ignore = loss_class(
        pred, target, ignore_index=250, reduction_override='sum')
    assert isinstance(loss_with_forward_ignore, torch.Tensor)

    # Verify correctness by dropping the ignored entries manually
    not_ignored_indices = (target != 250)
    pred = pred[not_ignored_indices]
    target = target[not_ignored_indices]
    loss = loss_class(pred, target, reduction_override='sum')

    assert torch.allclose(loss, loss_with_ignore)
    assert torch.allclose(loss, loss_with_forward_ignore)


def test_dice_loss():
    loss_class = DiceLoss
    pred = torch.rand((10, 4, 4))
    target = torch.rand((10, 4, 4))
    weight = torch.rand((10))

    # Test loss forward
    loss = loss_class()(pred, target)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with weight
    loss = loss_class()(pred, target, weight)
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with reduction_override
    loss = loss_class()(pred, target, reduction_override='mean')
    assert isinstance(loss, torch.Tensor)

    # Test loss forward with avg_factor
    loss = loss_class()(pred, target, avg_factor=10)
    assert isinstance(loss, torch.Tensor)

    with pytest.raises(ValueError):
        # loss can evaluate with avg_factor only if
        # reduction is None, 'none' or 'mean'.
        reduction_override = 'sum'
        loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)

    # Test loss forward with avg_factor and reduction
    for reduction_override in [None, 'none', 'mean']:
        loss = loss_class()(
            pred, target, avg_factor=10, reduction_override=reduction_override)
        assert isinstance(loss, torch.Tensor)

    # Test loss forward with activate=True and use_sigmoid=False
    with pytest.raises(NotImplementedError):
        loss_class(use_sigmoid=False, activate=True)(pred, target)

    # Test loss forward with weight.ndim != loss.ndim
    with pytest.raises(AssertionError):
        weight = torch.rand((2, 8))
        loss_class()(pred, target, weight)

    # Test loss forward with len(weight) != len(pred)
    with pytest.raises(AssertionError):
        weight = torch.rand((8))
        loss_class()(pred, target, weight)
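
For reference, below is a minimal standalone sketch of the call pattern these tests exercise, assuming mmdet and its dependencies are installed; L1Loss and the tensor shapes are illustrative choices only, and the test file itself is normally run through pytest (for example: pytest test_loss.py).

import torch

from mmdet.models.losses import L1Loss

# Instantiate a loss module and call it the way the tests above do: an
# optional element-wise weight plus reduction_override / avg_factor kwargs.
loss_fn = L1Loss(reduction='mean', loss_weight=1.0)
pred = torch.rand((10, 4))
target = torch.rand((10, 4))
weight = torch.rand((10, 4))

loss = loss_fn(pred, target, weight, avg_factor=10,
               reduction_override='mean')
assert isinstance(loss, torch.Tensor)  # scalar loss value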
