You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_nms.py 2.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import pytest
  2. import torch
  3. from mmdet.core.post_processing import mask_matrix_nms
  4. def _create_mask(N, h, w):
  5. masks = torch.rand((N, h, w)) > 0.5
  6. labels = torch.rand(N)
  7. scores = torch.rand(N)
  8. return masks, labels, scores
  9. def test_nms_input_errors():
  10. with pytest.raises(AssertionError):
  11. mask_matrix_nms(
  12. torch.rand((10, 28, 28)), torch.rand(11), torch.rand(11))
  13. with pytest.raises(AssertionError):
  14. masks = torch.rand((10, 28, 28))
  15. mask_matrix_nms(
  16. masks,
  17. torch.rand(11),
  18. torch.rand(11),
  19. mask_area=masks.sum((1, 2)).float()[:8])
  20. with pytest.raises(NotImplementedError):
  21. mask_matrix_nms(
  22. torch.rand((10, 28, 28)),
  23. torch.rand(10),
  24. torch.rand(10),
  25. kernel='None')
  26. # test an empty results
  27. masks, labels, scores = _create_mask(0, 28, 28)
  28. score, label, mask, keep_ind = \
  29. mask_matrix_nms(masks, labels, scores)
  30. assert len(score) == len(label) == \
  31. len(mask) == len(keep_ind) == 0
  32. # do not use update_thr, nms_pre and max_num
  33. masks, labels, scores = _create_mask(1000, 28, 28)
  34. score, label, mask, keep_ind = \
  35. mask_matrix_nms(masks, labels, scores)
  36. assert len(score) == len(label) == \
  37. len(mask) == len(keep_ind) == 1000
  38. # only use nms_pre
  39. score, label, mask, keep_ind = \
  40. mask_matrix_nms(masks, labels, scores, nms_pre=500)
  41. assert len(score) == len(label) == \
  42. len(mask) == len(keep_ind) == 500
  43. # use max_num
  44. score, label, mask, keep_ind = \
  45. mask_matrix_nms(masks, labels, scores,
  46. nms_pre=500, max_num=100)
  47. assert len(score) == len(label) == \
  48. len(mask) == len(keep_ind) == 100
  49. masks, labels, _ = _create_mask(1, 28, 28)
  50. scores = torch.Tensor([1.0])
  51. masks = masks.expand(1000, 28, 28)
  52. labels = labels.expand(1000)
  53. scores = scores.expand(1000)
  54. # assert scores is decayed and update_thr is worked
  55. # if with the same mask, label, and all scores = 1
  56. # the first score will set to 1, others will decay.
  57. score, label, mask, keep_ind = \
  58. mask_matrix_nms(masks,
  59. labels,
  60. scores,
  61. nms_pre=500,
  62. max_num=100,
  63. kernel='gaussian',
  64. sigma=2.0,
  65. filter_thr=0.5)
  66. assert len(score) == 1
  67. assert score[0] == 1

No Description

Contributors (1)