You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_visualizer.py 5.5 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. # File:
  4. import numpy as np
  5. import unittest
  6. import torch
  7. from detectron2.data import MetadataCatalog
  8. from detectron2.structures import BoxMode, Instances, RotatedBoxes
  9. from detectron2.utils.visualizer import Visualizer
  10. class TestVisualizer(unittest.TestCase):
  11. def _random_data(self):
  12. H, W = 100, 100
  13. N = 10
  14. img = np.random.rand(H, W, 3) * 255
  15. boxxy = np.random.rand(N, 2) * (H // 2)
  16. boxes = np.concatenate((boxxy, boxxy + H // 2), axis=1)
  17. def _rand_poly():
  18. return np.random.rand(3, 2).flatten() * H
  19. polygons = [[_rand_poly() for _ in range(np.random.randint(1, 5))] for _ in range(N)]
  20. mask = np.zeros_like(img[:, :, 0], dtype=np.bool)
  21. mask[:10, 10:20] = 1
  22. labels = [str(i) for i in range(N)]
  23. return img, boxes, labels, polygons, [mask] * N
  24. @property
  25. def metadata(self):
  26. return MetadataCatalog.get("coco_2017_train")
  27. def test_draw_dataset_dict(self):
  28. img = np.random.rand(512, 512, 3) * 255
  29. dic = {
  30. "annotations": [
  31. {
  32. "bbox": [
  33. 368.9946492271106,
  34. 330.891438763377,
  35. 13.148537455410235,
  36. 13.644708680142685,
  37. ],
  38. "bbox_mode": BoxMode.XYWH_ABS,
  39. "category_id": 0,
  40. "iscrowd": 1,
  41. "segmentation": {
  42. "counts": "_jh52m?2N2N2N2O100O10O001N1O2MceP2",
  43. "size": [512, 512],
  44. },
  45. }
  46. ],
  47. "height": 512,
  48. "image_id": 1,
  49. "width": 512,
  50. }
  51. v = Visualizer(img, self.metadata)
  52. v.draw_dataset_dict(dic)
  53. def test_overlay_instances(self):
  54. img, boxes, labels, polygons, masks = self._random_data()
  55. v = Visualizer(img, self.metadata)
  56. output = v.overlay_instances(masks=polygons, boxes=boxes, labels=labels).get_image()
  57. self.assertEqual(output.shape, img.shape)
  58. # Test 2x scaling
  59. v = Visualizer(img, self.metadata, scale=2.0)
  60. output = v.overlay_instances(masks=polygons, boxes=boxes, labels=labels).get_image()
  61. self.assertEqual(output.shape[0], img.shape[0] * 2)
  62. # Test overlay masks
  63. v = Visualizer(img, self.metadata)
  64. output = v.overlay_instances(masks=masks, boxes=boxes, labels=labels).get_image()
  65. self.assertEqual(output.shape, img.shape)
  66. def test_overlay_instances_no_boxes(self):
  67. img, boxes, labels, polygons, _ = self._random_data()
  68. v = Visualizer(img, self.metadata)
  69. v.overlay_instances(masks=polygons, boxes=None, labels=labels).get_image()
  70. def test_draw_instance_predictions(self):
  71. img, boxes, _, _, masks = self._random_data()
  72. num_inst = len(boxes)
  73. inst = Instances((img.shape[0], img.shape[1]))
  74. inst.pred_classes = torch.randint(0, 80, size=(num_inst,))
  75. inst.scores = torch.rand(num_inst)
  76. inst.pred_boxes = torch.from_numpy(boxes)
  77. inst.pred_masks = torch.from_numpy(np.asarray(masks))
  78. v = Visualizer(img, self.metadata)
  79. v.draw_instance_predictions(inst)
  80. def test_draw_empty_mask_predictions(self):
  81. img, boxes, _, _, masks = self._random_data()
  82. num_inst = len(boxes)
  83. inst = Instances((img.shape[0], img.shape[1]))
  84. inst.pred_classes = torch.randint(0, 80, size=(num_inst,))
  85. inst.scores = torch.rand(num_inst)
  86. inst.pred_boxes = torch.from_numpy(boxes)
  87. inst.pred_masks = torch.from_numpy(np.zeros_like(np.asarray(masks)))
  88. v = Visualizer(img, self.metadata)
  89. v.draw_instance_predictions(inst)
  90. def test_correct_output_shape(self):
  91. img = np.random.rand(928, 928, 3) * 255
  92. v = Visualizer(img, self.metadata)
  93. out = v.output.get_image()
  94. self.assertEqual(out.shape, img.shape)
  95. def test_overlay_rotated_instances(self):
  96. H, W = 100, 150
  97. img = np.random.rand(H, W, 3) * 255
  98. num_boxes = 50
  99. boxes_5d = torch.zeros(num_boxes, 5)
  100. boxes_5d[:, 0] = torch.FloatTensor(num_boxes).uniform_(-0.1 * W, 1.1 * W)
  101. boxes_5d[:, 1] = torch.FloatTensor(num_boxes).uniform_(-0.1 * H, 1.1 * H)
  102. boxes_5d[:, 2] = torch.FloatTensor(num_boxes).uniform_(0, max(W, H))
  103. boxes_5d[:, 3] = torch.FloatTensor(num_boxes).uniform_(0, max(W, H))
  104. boxes_5d[:, 4] = torch.FloatTensor(num_boxes).uniform_(-1800, 1800)
  105. rotated_boxes = RotatedBoxes(boxes_5d)
  106. labels = [str(i) for i in range(num_boxes)]
  107. v = Visualizer(img, self.metadata)
  108. output = v.overlay_instances(boxes=rotated_boxes, labels=labels).get_image()
  109. self.assertEqual(output.shape, img.shape)
  110. def test_draw_no_metadata(self):
  111. img, boxes, _, _, masks = self._random_data()
  112. num_inst = len(boxes)
  113. inst = Instances((img.shape[0], img.shape[1]))
  114. inst.pred_classes = torch.randint(0, 80, size=(num_inst,))
  115. inst.scores = torch.rand(num_inst)
  116. inst.pred_boxes = torch.from_numpy(boxes)
  117. inst.pred_masks = torch.from_numpy(np.asarray(masks))
  118. v = Visualizer(img, MetadataCatalog.get("asdfasdf"))
  119. v.draw_instance_predictions(inst)

No Description