You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_uniform_augment.py 9.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import matplotlib.pyplot as plt
  16. import numpy as np
  17. import mindspore.dataset.engine as de
  18. import mindspore.dataset.transforms.vision.c_transforms as C
  19. import mindspore.dataset.transforms.vision.py_transforms as F
  20. from mindspore import log as logger
  # Root directory of the ImageFolder-style test images read by the tests below.
  DATA_DIR = "../data/dataset/testImageNetData/train/"


  def visualize(image_original, image_ua):
      """
      Plot original images (top row) above their UniformAugment outputs (bottom row).

      Args:
          image_original: indexable collection of original images; it must hold
              at least len(image_ua) entries.
          image_ua: indexable collection of augmented images; its length sets
              the number of grid columns.
      """
      count = len(image_ua)
      for col in range(count):
          top_slot = col + 1
          bottom_slot = top_slot + count  # same column, second row of the grid

          plt.subplot(2, count, top_slot)
          plt.imshow(image_original[col])
          plt.title("Original image")

          plt.subplot(2, count, bottom_slot)
          plt.imshow(image_ua[col])
          plt.title("DE UniformAugment image")
      plt.show()
  def test_uniform_augment(plot=False, num_ops=2):
      """
      Test the Python UniformAugment op against un-augmented images.

      Both pipelines decode, resize to 224x224 and apply ToTensor; the second
      pipeline additionally applies UniformAugment with ``num_ops`` ops chosen
      from ``transform_list``. The mean squared error between the two outputs
      is logged.

      Args:
          plot (bool): if True, show original vs augmented images side by side.
          num_ops (int): number of ops UniformAugment selects per image.
      """
      logger.info("Test UniformAugment")

      # Original Images
      ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
      transforms_original = F.ComposeOp([F.Decode(),
                                         F.Resize((224, 224)),
                                         F.ToTensor()])
      ds_original = ds.map(input_columns="image",
                           operations=transforms_original())
      ds_original = ds_original.batch(512)
      # Concatenate all batches once; np.append inside a loop re-copies the
      # whole accumulated array on every batch. The transpose moves axis 1
      # (presumably channels, as produced by ToTensor) to the end for plotting.
      images_original = np.concatenate(
          [np.transpose(image, (0, 2, 3, 1)) for image, _ in ds_original], axis=0)

      # UniformAugment Images
      ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
      transform_list = [F.RandomRotation(45),
                        F.RandomColor(),
                        F.RandomSharpness(),
                        F.Invert(),
                        F.AutoContrast(),
                        F.Equalize()]
      transforms_ua = F.ComposeOp([F.Decode(),
                                   F.Resize((224, 224)),
                                   F.UniformAugment(transforms=transform_list, num_ops=num_ops),
                                   F.ToTensor()])
      ds_ua = ds.map(input_columns="image",
                     operations=transforms_ua())
      ds_ua = ds_ua.batch(512)
      images_ua = np.concatenate(
          [np.transpose(image, (0, 2, 3, 1)) for image, _ in ds_ua], axis=0)

      # Both pipelines read the same images and resize to 224x224, so their
      # outputs are expected to line up one-to-one. This also makes the
      # element-wise MSE below well defined.
      assert images_ua.shape == images_original.shape

      # Per-image MSE (mean over H, W, C), then averaged over all images.
      mse = np.mean((images_ua - images_original) ** 2, axis=(1, 2, 3))
      logger.info("MSE= {}".format(str(np.mean(mse))))

      if plot:
          visualize(images_original, images_ua)
  def test_cpp_uniform_augment(plot=False, num_ops=2):
      """
      Test the C++ UniformAugment op against un-augmented images.

      Both pipelines decode and resize to 224x224 with C ops, then apply the
      Python ToTensor op; the second pipeline additionally applies C++
      UniformAugment with ``num_ops`` ops chosen from ``transforms_ua``. The
      mean squared error between the two outputs is logged.

      Args:
          plot (bool): if True, show original vs augmented images side by side.
          num_ops (int): number of ops UniformAugment selects per image.
      """
      logger.info("Test CPP UniformAugment")

      # Original Images
      ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
      transforms_original = [C.Decode(), C.Resize(size=[224, 224]),
                             F.ToTensor()]
      ds_original = ds.map(input_columns="image",
                           operations=transforms_original)
      ds_original = ds_original.batch(512)
      # Concatenate all batches once; np.append inside a loop re-copies the
      # whole accumulated array on every batch. The transpose moves axis 1
      # (presumably channels, as produced by ToTensor) to the end for plotting.
      images_original = np.concatenate(
          [np.transpose(image, (0, 2, 3, 1)) for image, _ in ds_original], axis=0)

      # UniformAugment Images
      ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
      transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
                       C.RandomHorizontalFlip(),
                       C.RandomVerticalFlip(),
                       C.RandomColorAdjust(),
                       C.RandomRotation(degrees=45)]
      uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
      transforms_all = [C.Decode(), C.Resize(size=[224, 224]),
                        uni_aug,
                        F.ToTensor()]
      ds_ua = ds.map(input_columns="image",
                     operations=transforms_all, num_parallel_workers=1)
      ds_ua = ds_ua.batch(512)
      images_ua = np.concatenate(
          [np.transpose(image, (0, 2, 3, 1)) for image, _ in ds_ua], axis=0)

      # Both pipelines read the same images and resize to 224x224, and
      # RandomCrop crops back to 224x224, so outputs are expected to line up
      # one-to-one. This also makes the element-wise MSE below well defined.
      assert images_ua.shape == images_original.shape

      if plot:
          visualize(images_original, images_ua)

      # Per-image MSE (mean over H, W, C), then averaged over all images.
      mse = np.mean((images_ua - images_original) ** 2, axis=(1, 2, 3))
      logger.info("MSE= {}".format(str(np.mean(mse))))
  def test_cpp_uniform_augment_exception_pyops(num_ops=2):
      """
      Test that C++ UniformAugment rejects a Python op in ``operations``.

      The operations list mixes C ops with the Python op ``F.Invert()``;
      constructing UniformAugment must raise an error whose message mentions
      "operations".

      Args:
          num_ops (int): number of ops to request from UniformAugment.
      """
      logger.info("Test CPP UniformAugment invalid OP exception")

      transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
                       C.RandomHorizontalFlip(),
                       C.RandomVerticalFlip(),
                       C.RandomColorAdjust(),
                       C.RandomRotation(degrees=45),
                       F.Invert()]

      try:
          C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
      except Exception as e:
          logger.info("Got an exception in DE: {}".format(str(e)))
          assert "operations" in str(e)
      else:
          # Without this branch the test would pass silently if validation regressed.
          raise AssertionError("UniformAugment accepted a Python op in operations")
  def test_cpp_uniform_augment_exception_large_numops(num_ops=6):
      """
      Test that C++ UniformAugment rejects ``num_ops`` larger than the op list.

      Five operations are supplied; with the default ``num_ops=6``, construction
      must raise an error whose message mentions "num_ops".

      Args:
          num_ops (int): number of ops to request; expected to exceed the list size.
      """
      logger.info("Test CPP UniformAugment invalid large num_ops exception")

      transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
                       C.RandomHorizontalFlip(),
                       C.RandomVerticalFlip(),
                       C.RandomColorAdjust(),
                       C.RandomRotation(degrees=45)]

      try:
          C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
      except Exception as e:
          logger.info("Got an exception in DE: {}".format(str(e)))
          assert "num_ops" in str(e)
      else:
          # Without this branch the test would pass silently if validation regressed.
          raise AssertionError("UniformAugment accepted num_ops={} for {} operations"
                               .format(num_ops, len(transforms_ua)))
  def test_cpp_uniform_augment_exception_nonpositive_numops(num_ops=0):
      """
      Test that C++ UniformAugment rejects a non-positive ``num_ops``.

      Construction must raise an error whose message mentions "num_ops".

      Args:
          num_ops (int): number of ops to request; expected to be <= 0.
      """
      logger.info("Test CPP UniformAugment invalid non-positive num_ops exception")

      transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
                       C.RandomHorizontalFlip(),
                       C.RandomVerticalFlip(),
                       C.RandomColorAdjust(),
                       C.RandomRotation(degrees=45)]

      try:
          C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
      except Exception as e:
          logger.info("Got an exception in DE: {}".format(str(e)))
          assert "num_ops" in str(e)
      else:
          # Without this branch the test would pass silently if validation regressed.
          raise AssertionError("UniformAugment accepted non-positive num_ops={}"
                               .format(num_ops))
  def test_cpp_uniform_augment_exception_float_numops(num_ops=2.5):
      """
      Test that C++ UniformAugment rejects a non-integer ``num_ops``.

      Construction must raise an error whose message mentions "integer".

      Args:
          num_ops: number of ops to request; expected to be a float such as 2.5.
      """
      logger.info("Test CPP UniformAugment invalid float num_ops exception")

      transforms_ua = [C.RandomCrop(size=[224, 224], padding=[32, 32, 32, 32]),
                       C.RandomHorizontalFlip(),
                       C.RandomVerticalFlip(),
                       C.RandomColorAdjust(),
                       C.RandomRotation(degrees=45)]

      try:
          C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
      except Exception as e:
          logger.info("Got an exception in DE: {}".format(str(e)))
          assert "integer" in str(e)
      else:
          # Without this branch the test would pass silently if validation regressed.
          raise AssertionError("UniformAugment accepted non-integer num_ops={!r}"
                               .format(num_ops))
  def test_cpp_uniform_augment_random_crop_badinput(num_ops=1):
      """
      Test C++ UniformAugment when RandomCrop's crop size exceeds the image size.

      The CIFAR-10 test images are 32x32 but RandomCrop requests 224x224, so
      iterating the pipeline must fail with an error mentioning "Crop size".
      UniformAugment chooses among the listed ops, so the failure relies on
      RandomCrop being applied to at least one image in the dataset.

      Args:
          num_ops (int): number of ops UniformAugment selects per image.
      """
      logger.info("Test CPP UniformAugment with random_crop bad input")
      batch_size = 2
      cifar10_dir = "../data/dataset/testCifar10Data"
      ds1 = de.Cifar10Dataset(cifar10_dir, shuffle=False)  # shape = [32,32,3]
      transforms_ua = [
          # Note: crop size [224, 224] > image size [32, 32]
          C.RandomCrop(size=[224, 224]),
          C.RandomHorizontalFlip()
      ]
      uni_aug = C.UniformAugment(operations=transforms_ua, num_ops=num_ops)
      ds1 = ds1.map(input_columns="image", operations=uni_aug)

      # apply DatasetOps
      ds1 = ds1.batch(batch_size, drop_remainder=True, num_parallel_workers=1)

      num_batches = 0
      try:
          for _ in ds1.create_dict_iterator():
              num_batches += 1
      except Exception as e:
          assert "Crop size" in str(e)
      else:
          # Reaching here means the oversized crop was never rejected.
          raise AssertionError("Pipeline produced {} batches without the expected "
                               "crop-size error".format(num_batches))
  if __name__ == "__main__":
      # Each test paired with the num_ops value it is run with, in file order.
      scenarios = (
          (test_uniform_augment, 1),
          (test_cpp_uniform_augment, 1),
          (test_cpp_uniform_augment_exception_pyops, 1),
          (test_cpp_uniform_augment_exception_large_numops, 6),
          (test_cpp_uniform_augment_exception_nonpositive_numops, 0),
          (test_cpp_uniform_augment_exception_float_numops, 2.5),
          (test_cpp_uniform_augment_random_crop_badinput, 1),
      )
      for scenario, ops in scenarios:
          scenario(num_ops=ops)