You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_boxes.py 2.0 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import numpy as np
  3. import unittest
  4. import torch
  5. from detectron2.structures import Boxes, BoxMode, pairwise_iou
  6. class TestBoxMode(unittest.TestCase):
  7. def _convert_xy_to_wh(self, x):
  8. return BoxMode.convert(x, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
  9. def test_box_convert_list(self):
  10. for tp in [list, tuple]:
  11. box = tp([5, 5, 10, 10])
  12. output = self._convert_xy_to_wh(box)
  13. self.assertTrue(output == tp([5, 5, 5, 5]))
  14. with self.assertRaises(Exception):
  15. self._convert_xy_to_wh([box])
  16. def test_box_convert_array(self):
  17. box = np.asarray([[5, 5, 10, 10], [1, 1, 2, 3]])
  18. output = self._convert_xy_to_wh(box)
  19. self.assertTrue((output[0] == [5, 5, 5, 5]).all())
  20. self.assertTrue((output[1] == [1, 1, 1, 2]).all())
  21. def test_box_convert_tensor(self):
  22. box = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]])
  23. output = self._convert_xy_to_wh(box).numpy()
  24. self.assertTrue((output[0] == [5, 5, 5, 5]).all())
  25. self.assertTrue((output[1] == [1, 1, 1, 2]).all())
  26. class TestBoxIOU(unittest.TestCase):
  27. def test_pairwise_iou(self):
  28. boxes1 = torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]])
  29. boxes2 = torch.tensor(
  30. [
  31. [0.0, 0.0, 1.0, 1.0],
  32. [0.0, 0.0, 0.5, 1.0],
  33. [0.0, 0.0, 1.0, 0.5],
  34. [0.0, 0.0, 0.5, 0.5],
  35. [0.5, 0.5, 1.0, 1.0],
  36. [0.5, 0.5, 1.5, 1.5],
  37. ]
  38. )
  39. expected_ious = torch.tensor(
  40. [
  41. [1.0, 0.5, 0.5, 0.25, 0.25, 0.25 / (2 - 0.25)],
  42. [1.0, 0.5, 0.5, 0.25, 0.25, 0.25 / (2 - 0.25)],
  43. ]
  44. )
  45. ious = pairwise_iou(Boxes(boxes1), Boxes(boxes2))
  46. assert torch.allclose(ious, expected_ious)
  47. if __name__ == "__main__":
  48. unittest.main()

No Description