You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_bounding_box_augment.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the bounding box augment op in DE
  17. """
  18. from enum import Enum
  19. import mindspore.log as logger
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.vision.c_transforms as c_vision
  22. import matplotlib.pyplot as plt
  23. import matplotlib.patches as patches
  24. import numpy as np
  25. GENERATE_GOLDEN = False
  26. DATA_DIR = "../data/dataset/testVOC2012_2"
  27. class BoxType(Enum):
  28. """
  29. Defines box types for test cases
  30. """
  31. WidthOverflow = 1
  32. HeightOverflow = 2
  33. NegativeXY = 3
  34. OnEdge = 4
  35. WrongShape = 5
  36. def add_bad_annotation(img, bboxes, box_type):
  37. """
  38. Used to generate erroneous bounding box examples on given img.
  39. :param img: image where the bounding boxes are.
  40. :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
  41. :param box_type: type of bad box
  42. :return: bboxes with bad examples added
  43. """
  44. height = img.shape[0]
  45. width = img.shape[1]
  46. if box_type == BoxType.WidthOverflow:
  47. # use box that overflows on width
  48. return img, np.array([[0, 0, width + 1, height, 0, 0, 0]]).astype(np.uint32)
  49. if box_type == BoxType.HeightOverflow:
  50. # use box that overflows on height
  51. return img, np.array([[0, 0, width, height + 1, 0, 0, 0]]).astype(np.uint32)
  52. if box_type == BoxType.NegativeXY:
  53. # use box with negative xy
  54. return img, np.array([[-10, -10, width, height, 0, 0, 0]]).astype(np.uint32)
  55. if box_type == BoxType.OnEdge:
  56. # use box that covers the whole image
  57. return img, np.array([[0, 0, width, height, 0, 0, 0]]).astype(np.uint32)
  58. if box_type == BoxType.WrongShape:
  59. # use box that covers the whole image
  60. return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
  61. return img, bboxes
  62. def check_bad_box(data, box_type, expected_error):
  63. """
  64. :param data: de object detection pipeline
  65. :param box_type: type of bad box
  66. :param expected_error: error expected to get due to bad box
  67. :return: None
  68. """
  69. try:
  70. test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1),
  71. 1) # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
  72. data = data.map(input_columns=["annotation"],
  73. output_columns=["annotation"],
  74. operations=fix_annotate)
  75. # map to use width overflow
  76. data = data.map(input_columns=["image", "annotation"],
  77. output_columns=["image", "annotation"],
  78. columns_order=["image", "annotation"],
  79. operations=lambda img, bboxes: add_bad_annotation(img, bboxes, box_type))
  80. # map to apply ops
  81. data = data.map(input_columns=["image", "annotation"],
  82. output_columns=["image", "annotation"],
  83. columns_order=["image", "annotation"],
  84. operations=[test_op]) # Add column for "annotation"
  85. for _, _ in enumerate(data.create_dict_iterator()):
  86. break
  87. except RuntimeError as error:
  88. logger.info("Got an exception in DE: {}".format(str(error)))
  89. assert expected_error in str(error)
  90. def fix_annotate(bboxes):
  91. """
  92. Fix annotations to format followed by mindspore.
  93. :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
  94. :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
  95. """
  96. for bbox in bboxes:
  97. tmp = bbox[0]
  98. bbox[0] = bbox[1]
  99. bbox[1] = bbox[2]
  100. bbox[2] = bbox[3]
  101. bbox[3] = bbox[4]
  102. bbox[4] = tmp
  103. return bboxes
  104. def add_bounding_boxes(axis, bboxes):
  105. """
  106. :param axis: axis to modify
  107. :param bboxes: bounding boxes to draw on the axis
  108. :return: None
  109. """
  110. for bbox in bboxes:
  111. rect = patches.Rectangle((bbox[0], bbox[1]),
  112. bbox[2], bbox[3],
  113. linewidth=1, edgecolor='r', facecolor='none')
  114. # Add the patch to the Axes
  115. axis.add_patch(rect)
  116. def visualize(unaugmented_data, augment_data):
  117. """
  118. :param unaugmented_data: original data
  119. :param augment_data: data after augmentations
  120. :return: None
  121. """
  122. for idx, (un_aug_item, aug_item) in \
  123. enumerate(zip(unaugmented_data.create_dict_iterator(),
  124. augment_data.create_dict_iterator())):
  125. axis = plt.subplot(141)
  126. plt.imshow(un_aug_item["image"])
  127. add_bounding_boxes(axis, un_aug_item["annotation"]) # add Orig BBoxes
  128. plt.title("Original" + str(idx + 1))
  129. logger.info("Original ", str(idx + 1), " :", un_aug_item["annotation"])
  130. axis = plt.subplot(142)
  131. plt.imshow(aug_item["image"])
  132. add_bounding_boxes(axis, aug_item["annotation"]) # add AugBBoxes
  133. plt.title("Augmented" + str(idx + 1))
  134. logger.info("Augmented ", str(idx + 1), " ", aug_item["annotation"], "\n")
  135. plt.show()
  136. def test_bounding_box_augment_with_rotation_op(plot=False):
  137. """
  138. Test BoundingBoxAugment op
  139. Prints images side by side with and without Aug applied + bboxes to compare and test
  140. """
  141. logger.info("test_bounding_box_augment_with_rotation_op")
  142. data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  143. data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  144. test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
  145. # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
  146. # maps to fix annotations to minddata standard
  147. data_voc1 = data_voc1.map(input_columns=["annotation"],
  148. output_columns=["annotation"],
  149. operations=fix_annotate)
  150. data_voc2 = data_voc2.map(input_columns=["annotation"],
  151. output_columns=["annotation"],
  152. operations=fix_annotate)
  153. # map to apply ops
  154. data_voc2 = data_voc2.map(input_columns=["image", "annotation"],
  155. output_columns=["image", "annotation"],
  156. columns_order=["image", "annotation"],
  157. operations=[test_op]) # Add column for "annotation"
  158. if plot:
  159. visualize(data_voc1, data_voc2)
  160. def test_bounding_box_augment_with_crop_op(plot=False):
  161. """
  162. Test BoundingBoxAugment op
  163. Prints images side by side with and without Aug applied + bboxes to compare and test
  164. """
  165. logger.info("test_bounding_box_augment_with_crop_op")
  166. data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  167. data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  168. test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(90), 1)
  169. # maps to fix annotations to minddata standard
  170. data_voc1 = data_voc1.map(input_columns=["annotation"],
  171. output_columns=["annotation"],
  172. operations=fix_annotate)
  173. data_voc2 = data_voc2.map(input_columns=["annotation"],
  174. output_columns=["annotation"],
  175. operations=fix_annotate)
  176. # map to apply ops
  177. data_voc2 = data_voc2.map(input_columns=["image", "annotation"],
  178. output_columns=["image", "annotation"],
  179. columns_order=["image", "annotation"],
  180. operations=[test_op]) # Add column for "annotation"
  181. if plot:
  182. visualize(data_voc1, data_voc2)
  183. def test_bounding_box_augment_valid_ratio_c(plot=False):
  184. """
  185. Test RandomHorizontalFlipWithBBox op
  186. Prints images side by side with and without Aug applied + bboxes to compare and test
  187. """
  188. logger.info("test_bounding_box_augment_valid_ratio_c")
  189. data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  190. data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  191. test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
  192. # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
  193. # maps to fix annotations to minddata standard
  194. data_voc1 = data_voc1.map(input_columns=["annotation"],
  195. output_columns=["annotation"],
  196. operations=fix_annotate)
  197. data_voc2 = data_voc2.map(input_columns=["annotation"],
  198. output_columns=["annotation"],
  199. operations=fix_annotate)
  200. # map to apply ops
  201. data_voc2 = data_voc2.map(input_columns=["image", "annotation"],
  202. output_columns=["image", "annotation"],
  203. columns_order=["image", "annotation"],
  204. operations=[test_op]) # Add column for "annotation"
  205. if plot:
  206. visualize(data_voc1, data_voc2)
  207. def test_bounding_box_augment_invalid_ratio_c():
  208. """
  209. Test RandomHorizontalFlipWithBBox op with invalid input probability
  210. """
  211. logger.info("test_bounding_box_augment_invalid_ratio_c")
  212. data_voc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  213. data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  214. try:
  215. # ratio range is from 0 - 1
  216. test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1.5)
  217. # maps to fix annotations to minddata standard
  218. data_voc1 = data_voc1.map(input_columns=["annotation"],
  219. output_columns=["annotation"],
  220. operations=fix_annotate)
  221. data_voc2 = data_voc2.map(input_columns=["annotation"],
  222. output_columns=["annotation"],
  223. operations=fix_annotate)
  224. # map to apply ops
  225. data_voc2 = data_voc2.map(input_columns=["image", "annotation"],
  226. output_columns=["image", "annotation"],
  227. columns_order=["image", "annotation"],
  228. operations=[test_op]) # Add column for "annotation"
  229. except ValueError as error:
  230. logger.info("Got an exception in DE: {}".format(str(error)))
  231. assert "Input is not" in str(error)
  232. def test_bounding_box_augment_invalid_bounds_c():
  233. """
  234. Test BoundingBoxAugment op with invalid bboxes.
  235. """
  236. logger.info("test_bounding_box_augment_invalid_bounds_c")
  237. data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  238. check_bad_box(data_voc2, BoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
  239. data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  240. check_bad_box(data_voc2, BoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
  241. data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  242. check_bad_box(data_voc2, BoxType.NegativeXY, "min_x")
  243. data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
  244. check_bad_box(data_voc2, BoxType.WrongShape, "4 features")
  245. if __name__ == "__main__":
  246. # set to false to not show plots
  247. test_bounding_box_augment_with_rotation_op(False)
  248. test_bounding_box_augment_with_crop_op(False)
  249. test_bounding_box_augment_valid_ratio_c(False)
  250. test_bounding_box_augment_invalid_ratio_c()
  251. test_bounding_box_augment_invalid_bounds_c()