You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_losses.py 8.7 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models import Accuracy, build_loss
  5. def test_ce_loss():
  6. # use_mask and use_sigmoid cannot be true at the same time
  7. with pytest.raises(AssertionError):
  8. loss_cfg = dict(
  9. type='CrossEntropyLoss',
  10. use_mask=True,
  11. use_sigmoid=True,
  12. loss_weight=1.0)
  13. build_loss(loss_cfg)
  14. # test loss with class weights
  15. loss_cls_cfg = dict(
  16. type='CrossEntropyLoss',
  17. use_sigmoid=False,
  18. class_weight=[0.8, 0.2],
  19. loss_weight=1.0)
  20. loss_cls = build_loss(loss_cls_cfg)
  21. fake_pred = torch.Tensor([[100, -100]])
  22. fake_label = torch.Tensor([1]).long()
  23. assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
  24. loss_cls_cfg = dict(
  25. type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
  26. loss_cls = build_loss(loss_cls_cfg)
  27. assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))
  28. def test_varifocal_loss():
  29. # only sigmoid version of VarifocalLoss is implemented
  30. with pytest.raises(AssertionError):
  31. loss_cfg = dict(
  32. type='VarifocalLoss', use_sigmoid=False, loss_weight=1.0)
  33. build_loss(loss_cfg)
  34. # test that alpha should be greater than 0
  35. with pytest.raises(AssertionError):
  36. loss_cfg = dict(
  37. type='VarifocalLoss',
  38. alpha=-0.75,
  39. gamma=2.0,
  40. use_sigmoid=True,
  41. loss_weight=1.0)
  42. build_loss(loss_cfg)
  43. # test that pred and target should be of the same size
  44. loss_cls_cfg = dict(
  45. type='VarifocalLoss',
  46. use_sigmoid=True,
  47. alpha=0.75,
  48. gamma=2.0,
  49. iou_weighted=True,
  50. reduction='mean',
  51. loss_weight=1.0)
  52. loss_cls = build_loss(loss_cls_cfg)
  53. with pytest.raises(AssertionError):
  54. fake_pred = torch.Tensor([[100.0, -100.0]])
  55. fake_target = torch.Tensor([[1.0]])
  56. loss_cls(fake_pred, fake_target)
  57. # test the calculation
  58. loss_cls = build_loss(loss_cls_cfg)
  59. fake_pred = torch.Tensor([[100.0, -100.0]])
  60. fake_target = torch.Tensor([[1.0, 0.0]])
  61. assert torch.allclose(loss_cls(fake_pred, fake_target), torch.tensor(0.0))
  62. # test the loss with weights
  63. loss_cls = build_loss(loss_cls_cfg)
  64. fake_pred = torch.Tensor([[0.0, 100.0]])
  65. fake_target = torch.Tensor([[1.0, 1.0]])
  66. fake_weight = torch.Tensor([0.0, 1.0])
  67. assert torch.allclose(
  68. loss_cls(fake_pred, fake_target, fake_weight), torch.tensor(0.0))
  69. def test_kd_loss():
  70. # test that temperature should be greater than 1
  71. with pytest.raises(AssertionError):
  72. loss_cfg = dict(
  73. type='KnowledgeDistillationKLDivLoss', loss_weight=1.0, T=0.5)
  74. build_loss(loss_cfg)
  75. # test that pred and target should be of the same size
  76. loss_cls_cfg = dict(
  77. type='KnowledgeDistillationKLDivLoss', loss_weight=1.0, T=1)
  78. loss_cls = build_loss(loss_cls_cfg)
  79. with pytest.raises(AssertionError):
  80. fake_pred = torch.Tensor([[100, -100]])
  81. fake_label = torch.Tensor([1]).long()
  82. loss_cls(fake_pred, fake_label)
  83. # test the calculation
  84. loss_cls = build_loss(loss_cls_cfg)
  85. fake_pred = torch.Tensor([[100.0, 100.0]])
  86. fake_target = torch.Tensor([[1.0, 1.0]])
  87. assert torch.allclose(loss_cls(fake_pred, fake_target), torch.tensor(0.0))
  88. # test the loss with weights
  89. loss_cls = build_loss(loss_cls_cfg)
  90. fake_pred = torch.Tensor([[100.0, -100.0], [100.0, 100.0]])
  91. fake_target = torch.Tensor([[1.0, 0.0], [1.0, 1.0]])
  92. fake_weight = torch.Tensor([0.0, 1.0])
  93. assert torch.allclose(
  94. loss_cls(fake_pred, fake_target, fake_weight), torch.tensor(0.0))
  95. def test_seesaw_loss():
  96. # only softmax version of Seesaw Loss is implemented
  97. with pytest.raises(AssertionError):
  98. loss_cfg = dict(type='SeesawLoss', use_sigmoid=True, loss_weight=1.0)
  99. build_loss(loss_cfg)
  100. # test that cls_score.size(-1) == num_classes + 2
  101. loss_cls_cfg = dict(
  102. type='SeesawLoss', p=0.0, q=0.0, loss_weight=1.0, num_classes=2)
  103. loss_cls = build_loss(loss_cls_cfg)
  104. # the length of fake_pred should be num_classes + 2 = 4
  105. with pytest.raises(AssertionError):
  106. fake_pred = torch.Tensor([[-100, 100]])
  107. fake_label = torch.Tensor([1]).long()
  108. loss_cls(fake_pred, fake_label)
  109. # the length of fake_pred should be num_classes + 2 = 4
  110. with pytest.raises(AssertionError):
  111. fake_pred = torch.Tensor([[-100, 100, -100]])
  112. fake_label = torch.Tensor([1]).long()
  113. loss_cls(fake_pred, fake_label)
  114. # test the calculation without p and q
  115. loss_cls_cfg = dict(
  116. type='SeesawLoss', p=0.0, q=0.0, loss_weight=1.0, num_classes=2)
  117. loss_cls = build_loss(loss_cls_cfg)
  118. fake_pred = torch.Tensor([[-100, 100, -100, 100]])
  119. fake_label = torch.Tensor([1]).long()
  120. loss = loss_cls(fake_pred, fake_label)
  121. assert torch.allclose(loss['loss_cls_objectness'], torch.tensor(200.))
  122. assert torch.allclose(loss['loss_cls_classes'], torch.tensor(0.))
  123. # test the calculation with p and without q
  124. loss_cls_cfg = dict(
  125. type='SeesawLoss', p=1.0, q=0.0, loss_weight=1.0, num_classes=2)
  126. loss_cls = build_loss(loss_cls_cfg)
  127. fake_pred = torch.Tensor([[-100, 100, -100, 100]])
  128. fake_label = torch.Tensor([0]).long()
  129. loss_cls.cum_samples[0] = torch.exp(torch.Tensor([20]))
  130. loss = loss_cls(fake_pred, fake_label)
  131. assert torch.allclose(loss['loss_cls_objectness'], torch.tensor(200.))
  132. assert torch.allclose(loss['loss_cls_classes'], torch.tensor(180.))
  133. # test the calculation with q and without p
  134. loss_cls_cfg = dict(
  135. type='SeesawLoss', p=0.0, q=1.0, loss_weight=1.0, num_classes=2)
  136. loss_cls = build_loss(loss_cls_cfg)
  137. fake_pred = torch.Tensor([[-100, 100, -100, 100]])
  138. fake_label = torch.Tensor([0]).long()
  139. loss = loss_cls(fake_pred, fake_label)
  140. assert torch.allclose(loss['loss_cls_objectness'], torch.tensor(200.))
  141. assert torch.allclose(loss['loss_cls_classes'],
  142. torch.tensor(200.) + torch.tensor(100.).log())
  143. # test the others
  144. loss_cls_cfg = dict(
  145. type='SeesawLoss',
  146. p=0.0,
  147. q=1.0,
  148. loss_weight=1.0,
  149. num_classes=2,
  150. return_dict=False)
  151. loss_cls = build_loss(loss_cls_cfg)
  152. fake_pred = torch.Tensor([[100, -100, 100, -100]])
  153. fake_label = torch.Tensor([0]).long()
  154. loss = loss_cls(fake_pred, fake_label)
  155. acc = loss_cls.get_accuracy(fake_pred, fake_label)
  156. act = loss_cls.get_activation(fake_pred)
  157. assert torch.allclose(loss, torch.tensor(0.))
  158. assert torch.allclose(acc['acc_objectness'], torch.tensor(100.))
  159. assert torch.allclose(acc['acc_classes'], torch.tensor(100.))
  160. assert torch.allclose(act, torch.tensor([1., 0., 0.]))
  161. def test_accuracy():
  162. # test for empty pred
  163. pred = torch.empty(0, 4)
  164. label = torch.empty(0)
  165. accuracy = Accuracy(topk=1)
  166. acc = accuracy(pred, label)
  167. assert acc.item() == 0
  168. pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
  169. [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
  170. [0.0, 0.0, 0.99, 0]])
  171. # test for top1
  172. true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
  173. accuracy = Accuracy(topk=1)
  174. acc = accuracy(pred, true_label)
  175. assert acc.item() == 100
  176. # test for top1 with score thresh=0.8
  177. true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
  178. accuracy = Accuracy(topk=1, thresh=0.8)
  179. acc = accuracy(pred, true_label)
  180. assert acc.item() == 40
  181. # test for top2
  182. accuracy = Accuracy(topk=2)
  183. label = torch.Tensor([3, 2, 0, 0, 2]).long()
  184. acc = accuracy(pred, label)
  185. assert acc.item() == 100
  186. # test for both top1 and top2
  187. accuracy = Accuracy(topk=(1, 2))
  188. true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
  189. acc = accuracy(pred, true_label)
  190. for a in acc:
  191. assert a.item() == 100
  192. # topk is larger than pred class number
  193. with pytest.raises(AssertionError):
  194. accuracy = Accuracy(topk=5)
  195. accuracy(pred, true_label)
  196. # wrong topk type
  197. with pytest.raises(AssertionError):
  198. accuracy = Accuracy(topk='wrong type')
  199. accuracy(pred, true_label)
  200. # label size is larger than required
  201. with pytest.raises(AssertionError):
  202. label = torch.Tensor([2, 3, 0, 1, 2, 0]).long() # size mismatch
  203. accuracy = Accuracy()
  204. accuracy(pred, label)
  205. # wrong pred dimension
  206. with pytest.raises(AssertionError):
  207. accuracy = Accuracy()
  208. accuracy(pred[:, :, None], true_label)

No Description

Contributors (3)