You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_misc.py 5.9 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import numpy as np
  3. import pytest
  4. import torch
  5. from mmdet.core.bbox import distance2bbox
  6. from mmdet.core.mask.structures import BitmapMasks, PolygonMasks
  7. from mmdet.core.utils import (center_of_mass, filter_scores_and_topk,
  8. flip_tensor, mask2ndarray, select_single_mlvl)
  9. def dummy_raw_polygon_masks(size):
  10. """
  11. Args:
  12. size (tuple): expected shape of dummy masks, (N, H, W)
  13. Return:
  14. list[list[ndarray]]: dummy mask
  15. """
  16. num_obj, height, width = size
  17. polygons = []
  18. for _ in range(num_obj):
  19. num_points = np.random.randint(5) * 2 + 6
  20. polygons.append([np.random.uniform(0, min(height, width), num_points)])
  21. return polygons
  22. def test_mask2ndarray():
  23. raw_masks = np.ones((3, 28, 28))
  24. bitmap_mask = BitmapMasks(raw_masks, 28, 28)
  25. output_mask = mask2ndarray(bitmap_mask)
  26. assert np.allclose(raw_masks, output_mask)
  27. raw_masks = dummy_raw_polygon_masks((3, 28, 28))
  28. polygon_masks = PolygonMasks(raw_masks, 28, 28)
  29. output_mask = mask2ndarray(polygon_masks)
  30. assert output_mask.shape == (3, 28, 28)
  31. raw_masks = np.ones((3, 28, 28))
  32. output_mask = mask2ndarray(raw_masks)
  33. assert np.allclose(raw_masks, output_mask)
  34. raw_masks = torch.ones((3, 28, 28))
  35. output_mask = mask2ndarray(raw_masks)
  36. assert np.allclose(raw_masks, output_mask)
  37. # test unsupported type
  38. raw_masks = []
  39. with pytest.raises(TypeError):
  40. output_mask = mask2ndarray(raw_masks)
  41. def test_distance2bbox():
  42. point = torch.Tensor([[74., 61.], [-29., 106.], [138., 61.], [29., 170.]])
  43. distance = torch.Tensor([[0., 0, 1., 1.], [1., 2., 10., 6.],
  44. [22., -29., 138., 61.], [54., -29., 170., 61.]])
  45. expected_decode_bboxes = torch.Tensor([[74., 61., 75., 62.],
  46. [0., 104., 0., 112.],
  47. [100., 90., 100., 120.],
  48. [0., 120., 100., 120.]])
  49. out_bbox = distance2bbox(point, distance, max_shape=(120, 100))
  50. assert expected_decode_bboxes.allclose(out_bbox)
  51. out = distance2bbox(point, distance, max_shape=torch.Tensor((120, 100)))
  52. assert expected_decode_bboxes.allclose(out)
  53. batch_point = point.unsqueeze(0).repeat(2, 1, 1)
  54. batch_distance = distance.unsqueeze(0).repeat(2, 1, 1)
  55. batch_out = distance2bbox(
  56. batch_point, batch_distance, max_shape=(120, 100))[0]
  57. assert out.allclose(batch_out)
  58. batch_out = distance2bbox(
  59. batch_point, batch_distance, max_shape=[(120, 100), (120, 100)])[0]
  60. assert out.allclose(batch_out)
  61. batch_out = distance2bbox(point, batch_distance, max_shape=(120, 100))[0]
  62. assert out.allclose(batch_out)
  63. # test max_shape is not equal to batch
  64. with pytest.raises(AssertionError):
  65. distance2bbox(
  66. batch_point,
  67. batch_distance,
  68. max_shape=[(120, 100), (120, 100), (32, 32)])
  69. rois = torch.zeros((0, 4))
  70. deltas = torch.zeros((0, 4))
  71. out = distance2bbox(rois, deltas, max_shape=(120, 100))
  72. assert rois.shape == out.shape
  73. rois = torch.zeros((2, 0, 4))
  74. deltas = torch.zeros((2, 0, 4))
  75. out = distance2bbox(rois, deltas, max_shape=(120, 100))
  76. assert rois.shape == out.shape
  77. @pytest.mark.parametrize('mask', [
  78. torch.ones((28, 28)),
  79. torch.zeros((28, 28)),
  80. torch.rand(28, 28) > 0.5,
  81. torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
  82. ])
  83. def test_center_of_mass(mask):
  84. center_h, center_w = center_of_mass(mask)
  85. if mask.shape[0] == 4:
  86. assert center_h == 1.5
  87. assert center_w == 1.5
  88. assert isinstance(center_h, torch.Tensor) \
  89. and isinstance(center_w, torch.Tensor)
  90. assert 0 <= center_h <= 28 \
  91. and 0 <= center_w <= 28
  92. def test_flip_tensor():
  93. img = np.random.random((1, 3, 10, 10))
  94. src_tensor = torch.from_numpy(img)
  95. # test flip_direction parameter error
  96. with pytest.raises(AssertionError):
  97. flip_tensor(src_tensor, 'flip')
  98. # test tensor dimension
  99. with pytest.raises(AssertionError):
  100. flip_tensor(src_tensor[0], 'vertical')
  101. hfilp_tensor = flip_tensor(src_tensor, 'horizontal')
  102. expected_hflip_tensor = torch.from_numpy(img[..., ::-1, :].copy())
  103. expected_hflip_tensor.allclose(hfilp_tensor)
  104. vfilp_tensor = flip_tensor(src_tensor, 'vertical')
  105. expected_vflip_tensor = torch.from_numpy(img[..., ::-1].copy())
  106. expected_vflip_tensor.allclose(vfilp_tensor)
  107. diag_filp_tensor = flip_tensor(src_tensor, 'diagonal')
  108. expected_diag_filp_tensor = torch.from_numpy(img[..., ::-1, ::-1].copy())
  109. expected_diag_filp_tensor.allclose(diag_filp_tensor)
  110. def test_select_single_mlvl():
  111. mlvl_tensors = [torch.rand(2, 1, 10, 10)] * 5
  112. mlvl_tensor_list = select_single_mlvl(mlvl_tensors, 1)
  113. assert len(mlvl_tensor_list) == 5 and mlvl_tensor_list[0].ndim == 3
  114. def test_filter_scores_and_topk():
  115. score = torch.tensor([[0.1, 0.3, 0.2], [0.12, 0.7, 0.9], [0.02, 0.8, 0.08],
  116. [0.4, 0.1, 0.08]])
  117. bbox_pred = torch.tensor([[0.2, 0.3], [0.4, 0.7], [0.1, 0.1], [0.5, 0.1]])
  118. score_thr = 0.15
  119. nms_pre = 4
  120. # test results type error
  121. with pytest.raises(NotImplementedError):
  122. filter_scores_and_topk(score, score_thr, nms_pre, (score, ))
  123. filtered_results = filter_scores_and_topk(
  124. score, score_thr, nms_pre, results=dict(bbox_pred=bbox_pred))
  125. filtered_score, labels, keep_idxs, results = filtered_results
  126. assert filtered_score.allclose(torch.tensor([0.9, 0.8, 0.7, 0.4]))
  127. assert labels.allclose(torch.tensor([2, 1, 1, 0]))
  128. assert keep_idxs.allclose(torch.tensor([1, 2, 1, 3]))
  129. assert results['bbox_pred'].allclose(
  130. torch.tensor([[0.4, 0.7], [0.1, 0.1], [0.4, 0.7], [0.5, 0.1]]))

No Description

Contributors (2)