You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_box2box_transform.py 2.2 kB

3 years ago
Lines 1–58
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import logging
  3. import unittest
  4. import torch
  5. from detectron2.modeling.box_regression import Box2BoxTransform, Box2BoxTransformRotated
# Module-level logger; none of the tests visible in this file write to it.
logger = logging.getLogger(__name__)
  7. def random_boxes(mean_box, stdev, N):
  8. return torch.rand(N, 4) * stdev + torch.tensor(mean_box, dtype=torch.float)
  9. class TestBox2BoxTransform(unittest.TestCase):
  10. def test_reconstruction(self):
  11. weights = (5, 5, 10, 10)
  12. b2b_tfm = Box2BoxTransform(weights=weights)
  13. src_boxes = random_boxes([10, 10, 20, 20], 1, 10)
  14. dst_boxes = random_boxes([10, 10, 20, 20], 1, 10)
  15. devices = [torch.device("cpu")]
  16. if torch.cuda.is_available():
  17. devices.append(torch.device("cuda"))
  18. for device in devices:
  19. src_boxes = src_boxes.to(device=device)
  20. dst_boxes = dst_boxes.to(device=device)
  21. deltas = b2b_tfm.get_deltas(src_boxes, dst_boxes)
  22. dst_boxes_reconstructed = b2b_tfm.apply_deltas(deltas, src_boxes)
  23. assert torch.allclose(dst_boxes, dst_boxes_reconstructed)
  24. def random_rotated_boxes(mean_box, std_length, std_angle, N):
  25. return torch.cat(
  26. [torch.rand(N, 4) * std_length, torch.rand(N, 1) * std_angle], dim=1
  27. ) + torch.tensor(mean_box, dtype=torch.float)
  28. class TestBox2BoxTransformRotated(unittest.TestCase):
  29. def test_reconstruction(self):
  30. weights = (5, 5, 10, 10, 1)
  31. b2b_transform = Box2BoxTransformRotated(weights=weights)
  32. src_boxes = random_rotated_boxes([10, 10, 20, 20, -30], 5, 60.0, 10)
  33. dst_boxes = random_rotated_boxes([10, 10, 20, 20, -30], 5, 60.0, 10)
  34. devices = [torch.device("cpu")]
  35. if torch.cuda.is_available():
  36. devices.append(torch.device("cuda"))
  37. for device in devices:
  38. src_boxes = src_boxes.to(device=device)
  39. dst_boxes = dst_boxes.to(device=device)
  40. deltas = b2b_transform.get_deltas(src_boxes, dst_boxes)
  41. dst_boxes_reconstructed = b2b_transform.apply_deltas(deltas, src_boxes)
  42. assert torch.allclose(dst_boxes, dst_boxes_reconstructed, atol=1e-5)
if __name__ == "__main__":
    # Allows running this file directly; unittest.main() collects and runs
    # the TestCase classes defined in this module.
    unittest.main()

No Description