You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_data_transform.py 2.9 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # -*- coding: utf-8 -*-
  2. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  3. import logging
  4. import numpy as np
  5. import unittest
  6. from detectron2.config import get_cfg
  7. from detectron2.data import detection_utils
  8. from detectron2.data import transforms as T
  9. from detectron2.utils.logger import setup_logger
# Module-level logger named after this module; the tests visible in this file do not reference it.
logger = logging.getLogger(__name__)
  11. class TestTransforms(unittest.TestCase):
  12. def setUp(self):
  13. setup_logger()
  14. def test_apply_rotated_boxes(self):
  15. np.random.seed(125)
  16. cfg = get_cfg()
  17. is_train = True
  18. transform_gen = detection_utils.build_transform_gen(cfg, is_train)
  19. image = np.random.rand(200, 300)
  20. image, transforms = T.apply_transform_gens(transform_gen, image)
  21. image_shape = image.shape[:2] # h, w
  22. assert image_shape == (800, 1200)
  23. annotation = {"bbox": [179, 97, 62, 40, -56]}
  24. boxes = np.array([annotation["bbox"]], dtype=np.float64) # boxes.shape = (1, 5)
  25. transformed_bbox = transforms.apply_rotated_box(boxes)[0]
  26. expected_bbox = np.array([484, 388, 248, 160, 56], dtype=np.float64)
  27. err_msg = "transformed_bbox = {}, expected {}".format(transformed_bbox, expected_bbox)
  28. assert np.allclose(transformed_bbox, expected_bbox), err_msg
  29. def test_apply_rotated_boxes_unequal_scaling_factor(self):
  30. np.random.seed(125)
  31. h, w = 400, 200
  32. newh, neww = 800, 800
  33. image = np.random.rand(h, w)
  34. transform_gen = []
  35. transform_gen.append(T.Resize(shape=(newh, neww)))
  36. image, transforms = T.apply_transform_gens(transform_gen, image)
  37. image_shape = image.shape[:2] # h, w
  38. assert image_shape == (newh, neww)
  39. boxes = np.array(
  40. [
  41. [150, 100, 40, 20, 0],
  42. [150, 100, 40, 20, 30],
  43. [150, 100, 40, 20, 90],
  44. [150, 100, 40, 20, -90],
  45. ],
  46. dtype=np.float64,
  47. )
  48. transformed_boxes = transforms.apply_rotated_box(boxes)
  49. expected_bboxes = np.array(
  50. [
  51. [600, 200, 160, 40, 0],
  52. [600, 200, 144.22205102, 52.91502622, 49.10660535],
  53. [600, 200, 80, 80, 90],
  54. [600, 200, 80, 80, -90],
  55. ],
  56. dtype=np.float64,
  57. )
  58. err_msg = "transformed_boxes = {}, expected {}".format(transformed_boxes, expected_bboxes)
  59. assert np.allclose(transformed_boxes, expected_bboxes), err_msg
  60. def test_print_transform_gen(self):
  61. t = T.RandomCrop("relative", (100, 100))
  62. self.assertTrue(str(t) == "RandomCrop(crop_type='relative', crop_size=(100, 100))")
  63. t = T.RandomFlip(prob=0.5)
  64. self.assertTrue(str(t) == "RandomFlip(prob=0.5)")
  65. t = T.RandomFlip()
  66. self.assertTrue(str(t) == "RandomFlip()")

No Description