You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_visualization.py 4.4 kB

2 years ago
Lines 1–127
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os
  3. import os.path as osp
  4. import tempfile
  5. import mmcv
  6. import numpy as np
  7. import pytest
  8. import torch
  9. from mmdet.core import visualization as vis
  10. def test_color():
  11. assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
  12. assert vis.color_val_matplotlib('green') == (0., 1., 0.)
  13. assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
  14. assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
  15. assert vis.color_val_matplotlib(np.zeros(3, dtype=np.int)) == (0., 0., 0.)
  16. # forbid white color
  17. with pytest.raises(TypeError):
  18. vis.color_val_matplotlib([255, 255, 255])
  19. # forbid float
  20. with pytest.raises(TypeError):
  21. vis.color_val_matplotlib(1.0)
  22. # overflowed
  23. with pytest.raises(AssertionError):
  24. vis.color_val_matplotlib((0, 0, 500))
  25. def test_imshow_det_bboxes():
  26. tmp_filename = osp.join(tempfile.gettempdir(), 'det_bboxes_image',
  27. 'image.jpg')
  28. image = np.ones((10, 10, 3), np.uint8)
  29. bbox = np.array([[2, 1, 3, 3], [3, 4, 6, 6]])
  30. label = np.array([0, 1])
  31. out_image = vis.imshow_det_bboxes(
  32. image, bbox, label, out_file=tmp_filename, show=False)
  33. assert osp.isfile(tmp_filename)
  34. assert image.shape == out_image.shape
  35. assert not np.allclose(image, out_image)
  36. os.remove(tmp_filename)
  37. # test grayscale images
  38. image = np.ones((10, 10), np.uint8)
  39. bbox = np.array([[2, 1, 3, 3], [3, 4, 6, 6]])
  40. label = np.array([0, 1])
  41. out_image = vis.imshow_det_bboxes(
  42. image, bbox, label, out_file=tmp_filename, show=False)
  43. assert osp.isfile(tmp_filename)
  44. assert image.shape == out_image.shape[:2]
  45. os.remove(tmp_filename)
  46. # test shaped (0,)
  47. image = np.ones((10, 10, 3), np.uint8)
  48. bbox = np.ones((0, 4))
  49. label = np.ones((0, ))
  50. vis.imshow_det_bboxes(
  51. image, bbox, label, out_file=tmp_filename, show=False)
  52. assert osp.isfile(tmp_filename)
  53. os.remove(tmp_filename)
  54. # test mask
  55. image = np.ones((10, 10, 3), np.uint8)
  56. bbox = np.array([[2, 1, 3, 3], [3, 4, 6, 6]])
  57. label = np.array([0, 1])
  58. segms = np.random.random((2, 10, 10)) > 0.5
  59. segms = np.array(segms, np.int32)
  60. vis.imshow_det_bboxes(
  61. image, bbox, label, segms, out_file=tmp_filename, show=False)
  62. assert osp.isfile(tmp_filename)
  63. os.remove(tmp_filename)
  64. # test tensor mask type error
  65. with pytest.raises(AttributeError):
  66. segms = torch.tensor(segms)
  67. vis.imshow_det_bboxes(image, bbox, label, segms, show=False)
  68. def test_imshow_gt_det_bboxes():
  69. tmp_filename = osp.join(tempfile.gettempdir(), 'det_bboxes_image',
  70. 'image.jpg')
  71. image = np.ones((10, 10, 3), np.uint8)
  72. bbox = np.array([[2, 1, 3, 3], [3, 4, 6, 6]])
  73. label = np.array([0, 1])
  74. annotation = dict(gt_bboxes=bbox, gt_labels=label)
  75. det_result = np.array([[2, 1, 3, 3, 0], [3, 4, 6, 6, 1]])
  76. result = [det_result]
  77. out_image = vis.imshow_gt_det_bboxes(
  78. image, annotation, result, out_file=tmp_filename, show=False)
  79. assert osp.isfile(tmp_filename)
  80. assert image.shape == out_image.shape
  81. assert not np.allclose(image, out_image)
  82. os.remove(tmp_filename)
  83. # test grayscale images
  84. image = np.ones((10, 10), np.uint8)
  85. bbox = np.array([[2, 1, 3, 3], [3, 4, 6, 6]])
  86. label = np.array([0, 1])
  87. annotation = dict(gt_bboxes=bbox, gt_labels=label)
  88. det_result = np.array([[2, 1, 3, 3, 0], [3, 4, 6, 6, 1]])
  89. result = [det_result]
  90. vis.imshow_gt_det_bboxes(
  91. image, annotation, result, out_file=tmp_filename, show=False)
  92. assert osp.isfile(tmp_filename)
  93. os.remove(tmp_filename)
  94. # test numpy mask
  95. gt_mask = np.ones((2, 10, 10))
  96. annotation['gt_masks'] = gt_mask
  97. vis.imshow_gt_det_bboxes(
  98. image, annotation, result, out_file=tmp_filename, show=False)
  99. assert osp.isfile(tmp_filename)
  100. os.remove(tmp_filename)
  101. # test tensor mask
  102. gt_mask = torch.ones((2, 10, 10))
  103. annotation['gt_masks'] = gt_mask
  104. vis.imshow_gt_det_bboxes(
  105. image, annotation, result, out_file=tmp_filename, show=False)
  106. assert osp.isfile(tmp_filename)
  107. os.remove(tmp_filename)
  108. # test unsupported type
  109. annotation['gt_masks'] = []
  110. with pytest.raises(TypeError):
  111. vis.imshow_gt_det_bboxes(image, annotation, result, show=False)

No Description

Contributors (2)