You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_bounding_box_augment.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the bounding box augment op in DE
  17. """
  18. import numpy as np
  19. import mindspore.log as logger
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.vision.c_transforms as c_vision
  22. from util import visualize_with_bounding_boxes, InvalidBBoxType, check_bad_bbox, \
  23. config_get_set_seed, config_get_set_num_parallel_workers, save_and_check_md5
  24. GENERATE_GOLDEN = False
  25. # updated VOC dataset with correct annotations
  26. DATA_DIR = "../data/dataset/testVOC2012_2"
  27. DATA_DIR_2 = ["../data/dataset/testCOCO/train/",
  28. "../data/dataset/testCOCO/annotations/train.json"] # DATA_DIR, ANNOTATION_DIR
  29. def test_bounding_box_augment_with_rotation_op(plot_vis=False):
  30. """
  31. Test BoundingBoxAugment op (passing rotation op as transform)
  32. Prints images side by side with and without Aug applied + bboxes to compare and test
  33. """
  34. logger.info("test_bounding_box_augment_with_rotation_op")
  35. original_seed = config_get_set_seed(0)
  36. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  37. dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  38. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  39. # Ratio is set to 1 to apply rotation on all bounding boxes.
  40. test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
  41. # map to apply ops
  42. dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
  43. output_columns=["image", "bbox"],
  44. columns_order=["image", "bbox"],
  45. operations=[test_op])
  46. filename = "bounding_box_augment_rotation_c_result.npz"
  47. save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN)
  48. unaugSamp, augSamp = [], []
  49. for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
  50. unaugSamp.append(unAug)
  51. augSamp.append(Aug)
  52. if plot_vis:
  53. visualize_with_bounding_boxes(unaugSamp, augSamp)
  54. # Restore config setting
  55. ds.config.set_seed(original_seed)
  56. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  57. def test_bounding_box_augment_with_crop_op(plot_vis=False):
  58. """
  59. Test BoundingBoxAugment op (passing crop op as transform)
  60. Prints images side by side with and without Aug applied + bboxes to compare and test
  61. """
  62. logger.info("test_bounding_box_augment_with_crop_op")
  63. original_seed = config_get_set_seed(0)
  64. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  65. dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  66. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  67. # Ratio is set to 0.9 to apply RandomCrop of size (50, 50) on 90% of the bounding boxes.
  68. test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9)
  69. # map to apply ops
  70. dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
  71. output_columns=["image", "bbox"],
  72. columns_order=["image", "bbox"],
  73. operations=[test_op])
  74. filename = "bounding_box_augment_crop_c_result.npz"
  75. save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN)
  76. unaugSamp, augSamp = [], []
  77. for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
  78. unaugSamp.append(unAug)
  79. augSamp.append(Aug)
  80. if plot_vis:
  81. visualize_with_bounding_boxes(unaugSamp, augSamp)
  82. # Restore config setting
  83. ds.config.set_seed(original_seed)
  84. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  85. def test_bounding_box_augment_valid_ratio_c(plot_vis=False):
  86. """
  87. Test BoundingBoxAugment op (testing with valid ratio, less than 1.
  88. Prints images side by side with and without Aug applied + bboxes to compare and test
  89. """
  90. logger.info("test_bounding_box_augment_valid_ratio_c")
  91. original_seed = config_get_set_seed(1)
  92. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  93. dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  94. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  95. test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
  96. # map to apply ops
  97. dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
  98. output_columns=["image", "bbox"],
  99. columns_order=["image", "bbox"],
  100. operations=[test_op]) # Add column for "bbox"
  101. filename = "bounding_box_augment_valid_ratio_c_result.npz"
  102. save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN)
  103. unaugSamp, augSamp = [], []
  104. for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
  105. unaugSamp.append(unAug)
  106. augSamp.append(Aug)
  107. if plot_vis:
  108. visualize_with_bounding_boxes(unaugSamp, augSamp)
  109. # Restore config setting
  110. ds.config.set_seed(original_seed)
  111. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  112. def test_bounding_box_augment_op_coco_c(plot_vis=False):
  113. """
  114. Prints images and bboxes side by side with and without BoundingBoxAugment Op applied,
  115. Testing with COCO dataset
  116. """
  117. logger.info("test_bounding_box_augment_op_coco_c")
  118. dataCoco1 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection",
  119. decode=True, shuffle=False)
  120. dataCoco2 = ds.CocoDataset(DATA_DIR_2[0], annotation_file=DATA_DIR_2[1], task="Detection",
  121. decode=True, shuffle=False)
  122. test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1)
  123. dataCoco2 = dataCoco2.map(input_columns=["image", "bbox"],
  124. output_columns=["image", "bbox"],
  125. columns_order=["image", "bbox"],
  126. operations=[test_op])
  127. unaugSamp, augSamp = [], []
  128. for unAug, Aug in zip(dataCoco1.create_dict_iterator(), dataCoco2.create_dict_iterator()):
  129. unaugSamp.append(unAug)
  130. augSamp.append(Aug)
  131. if plot_vis:
  132. visualize_with_bounding_boxes(unaugSamp, augSamp, "bbox")
  133. def test_bounding_box_augment_valid_edge_c(plot_vis=False):
  134. """
  135. Test BoundingBoxAugment op (testing with valid edge case, box covering full image).
  136. Prints images side by side with and without Aug applied + bboxes to compare and test
  137. """
  138. logger.info("test_bounding_box_augment_valid_edge_c")
  139. original_seed = config_get_set_seed(1)
  140. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  141. dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  142. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  143. test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1)
  144. # map to apply ops
  145. # Add column for "bbox"
  146. dataVoc1 = dataVoc1.map(input_columns=["image", "bbox"],
  147. output_columns=["image", "bbox"],
  148. columns_order=["image", "bbox"],
  149. operations=lambda img, bbox:
  150. (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32)))
  151. dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
  152. output_columns=["image", "bbox"],
  153. columns_order=["image", "bbox"],
  154. operations=lambda img, bbox:
  155. (img, np.array([[0, 0, img.shape[1], img.shape[0], 0, 0, 0]]).astype(np.float32)))
  156. dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
  157. output_columns=["image", "bbox"],
  158. columns_order=["image", "bbox"],
  159. operations=[test_op])
  160. filename = "bounding_box_augment_valid_edge_c_result.npz"
  161. save_and_check_md5(dataVoc2, filename, generate_golden=GENERATE_GOLDEN)
  162. unaugSamp, augSamp = [], []
  163. for unAug, Aug in zip(dataVoc1.create_dict_iterator(), dataVoc2.create_dict_iterator()):
  164. unaugSamp.append(unAug)
  165. augSamp.append(Aug)
  166. if plot_vis:
  167. visualize_with_bounding_boxes(unaugSamp, augSamp)
  168. # Restore config setting
  169. ds.config.set_seed(original_seed)
  170. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  171. def test_bounding_box_augment_invalid_ratio_c():
  172. """
  173. Test BoundingBoxAugment op with invalid input ratio
  174. """
  175. logger.info("test_bounding_box_augment_invalid_ratio_c")
  176. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  177. try:
  178. # ratio range is from 0 - 1
  179. test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5)
  180. # map to apply ops
  181. dataVoc2 = dataVoc2.map(input_columns=["image", "bbox"],
  182. output_columns=["image", "bbox"],
  183. columns_order=["image", "bbox"],
  184. operations=[test_op]) # Add column for "bbox"
  185. except ValueError as error:
  186. logger.info("Got an exception in DE: {}".format(str(error)))
  187. assert "Input ratio is not within the required interval of (0.0 to 1.0)." in str(error)
  188. def test_bounding_box_augment_invalid_bounds_c():
  189. """
  190. Test BoundingBoxAugment op with invalid bboxes.
  191. """
  192. logger.info("test_bounding_box_augment_invalid_bounds_c")
  193. test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1),
  194. 1)
  195. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  196. check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
  197. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  198. check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
  199. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  200. check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
  201. dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  202. check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
  203. if __name__ == "__main__":
  204. # set to false to not show plots
  205. test_bounding_box_augment_with_rotation_op(plot_vis=False)
  206. test_bounding_box_augment_with_crop_op(plot_vis=False)
  207. test_bounding_box_augment_op_coco_c(plot_vis=False)
  208. test_bounding_box_augment_valid_ratio_c(plot_vis=False)
  209. test_bounding_box_augment_valid_edge_c(plot_vis=False)
  210. test_bounding_box_augment_invalid_ratio_c()
  211. test_bounding_box_augment_invalid_bounds_c()