You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_mixup_op.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago

  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the MixUpBatch op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.vision.c_transforms as vision
  22. import mindspore.dataset.transforms.c_transforms as data_trans
  23. from mindspore import log as logger
  24. from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
  25. config_get_set_num_parallel_workers
  26. DATA_DIR = "../data/dataset/testCifar10Data"
  27. DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
  28. GENERATE_GOLDEN = False
  29. def test_mixup_batch_success1(plot=False):
  30. """
  31. Test MixUpBatch op with specified alpha parameter
  32. """
  33. logger.info("test_mixup_batch_success1")
  34. # Original Images
  35. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  36. ds_original = ds_original.batch(5, drop_remainder=True)
  37. images_original = None
  38. for idx, (image, _) in enumerate(ds_original):
  39. if idx == 0:
  40. images_original = image
  41. else:
  42. images_original = np.append(images_original, image, axis=0)
  43. # MixUp Images
  44. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  45. one_hot_op = data_trans.OneHot(num_classes=10)
  46. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  47. mixup_batch_op = vision.MixUpBatch(2)
  48. data1 = data1.batch(5, drop_remainder=True)
  49. data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
  50. images_mixup = None
  51. for idx, (image, _) in enumerate(data1):
  52. if idx == 0:
  53. images_mixup = image
  54. else:
  55. images_mixup = np.append(images_mixup, image, axis=0)
  56. if plot:
  57. visualize_list(images_original, images_mixup)
  58. num_samples = images_original.shape[0]
  59. mse = np.zeros(num_samples)
  60. for i in range(num_samples):
  61. mse[i] = diff_mse(images_mixup[i], images_original[i])
  62. logger.info("MSE= {}".format(str(np.mean(mse))))
  63. def test_mixup_batch_success2(plot=False):
  64. """
  65. Test MixUpBatch op with specified alpha parameter on ImageFolderDatasetV2
  66. """
  67. logger.info("test_mixup_batch_success2")
  68. # Original Images
  69. ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
  70. decode_op = vision.Decode()
  71. ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
  72. ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
  73. images_original = None
  74. for idx, (image, _) in enumerate(ds_original):
  75. if idx == 0:
  76. images_original = image
  77. else:
  78. images_original = np.append(images_original, image, axis=0)
  79. # MixUp Images
  80. data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
  81. decode_op = vision.Decode()
  82. data1 = data1.map(input_columns=["image"], operations=[decode_op])
  83. one_hot_op = data_trans.OneHot(num_classes=10)
  84. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  85. mixup_batch_op = vision.MixUpBatch(2.0)
  86. data1 = data1.batch(4, pad_info={}, drop_remainder=True)
  87. data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
  88. images_mixup = None
  89. for idx, (image, _) in enumerate(data1):
  90. if idx == 0:
  91. images_mixup = image
  92. else:
  93. images_mixup = np.append(images_mixup, image, axis=0)
  94. if plot:
  95. visualize_list(images_original, images_mixup)
  96. num_samples = images_original.shape[0]
  97. mse = np.zeros(num_samples)
  98. for i in range(num_samples):
  99. mse[i] = diff_mse(images_mixup[i], images_original[i])
  100. logger.info("MSE= {}".format(str(np.mean(mse))))
  101. def test_mixup_batch_success3(plot=False):
  102. """
  103. Test MixUpBatch op without specified alpha parameter.
  104. Alpha parameter will be selected by default in this case
  105. """
  106. logger.info("test_mixup_batch_success3")
  107. # Original Images
  108. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  109. ds_original = ds_original.batch(5, drop_remainder=True)
  110. images_original = None
  111. for idx, (image, _) in enumerate(ds_original):
  112. if idx == 0:
  113. images_original = image
  114. else:
  115. images_original = np.append(images_original, image, axis=0)
  116. # MixUp Images
  117. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  118. one_hot_op = data_trans.OneHot(num_classes=10)
  119. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  120. mixup_batch_op = vision.MixUpBatch()
  121. data1 = data1.batch(5, drop_remainder=True)
  122. data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
  123. images_mixup = np.array([])
  124. for idx, (image, _) in enumerate(data1):
  125. if idx == 0:
  126. images_mixup = image
  127. else:
  128. images_mixup = np.append(images_mixup, image, axis=0)
  129. if plot:
  130. visualize_list(images_original, images_mixup)
  131. num_samples = images_original.shape[0]
  132. mse = np.zeros(num_samples)
  133. for i in range(num_samples):
  134. mse[i] = diff_mse(images_mixup[i], images_original[i])
  135. logger.info("MSE= {}".format(str(np.mean(mse))))
  136. def test_mixup_batch_md5():
  137. """
  138. Test MixUpBatch with MD5:
  139. """
  140. logger.info("test_mixup_batch_md5")
  141. original_seed = config_get_set_seed(0)
  142. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  143. # MixUp Images
  144. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  145. one_hot_op = data_trans.OneHot(num_classes=10)
  146. data = data.map(input_columns=["label"], operations=one_hot_op)
  147. mixup_batch_op = vision.MixUpBatch()
  148. data = data.batch(5, drop_remainder=True)
  149. data = data.map(input_columns=["image", "label"], operations=mixup_batch_op)
  150. filename = "mixup_batch_c_result.npz"
  151. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  152. # Restore config setting
  153. ds.config.set_seed(original_seed)
  154. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  155. def test_mixup_batch_fail1():
  156. """
  157. Test MixUpBatch Fail 1
  158. We expect this to fail because the images and labels are not batched
  159. """
  160. logger.info("test_mixup_batch_fail1")
  161. # Original Images
  162. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  163. ds_original = ds_original.batch(5)
  164. images_original = np.array([])
  165. for idx, (image, _) in enumerate(ds_original):
  166. if idx == 0:
  167. images_original = image
  168. else:
  169. images_original = np.append(images_original, image, axis=0)
  170. # MixUp Images
  171. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  172. one_hot_op = data_trans.OneHot(num_classes=10)
  173. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  174. mixup_batch_op = vision.MixUpBatch(0.1)
  175. with pytest.raises(RuntimeError) as error:
  176. data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
  177. for idx, (image, _) in enumerate(data1):
  178. if idx == 0:
  179. images_mixup = image
  180. else:
  181. images_mixup = np.append(images_mixup, image, axis=0)
  182. error_message = "You must make sure images are HWC or CHW and batch"
  183. assert error_message in str(error.value)
  184. def test_mixup_batch_fail2():
  185. """
  186. Test MixUpBatch Fail 2
  187. We expect this to fail because alpha is negative
  188. """
  189. logger.info("test_mixup_batch_fail2")
  190. # Original Images
  191. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  192. ds_original = ds_original.batch(5)
  193. images_original = np.array([])
  194. for idx, (image, _) in enumerate(ds_original):
  195. if idx == 0:
  196. images_original = image
  197. else:
  198. images_original = np.append(images_original, image, axis=0)
  199. # MixUp Images
  200. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  201. one_hot_op = data_trans.OneHot(num_classes=10)
  202. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  203. with pytest.raises(ValueError) as error:
  204. vision.MixUpBatch(-1)
  205. error_message = "Input is not within the required interval"
  206. assert error_message in str(error.value)
  207. def test_mixup_batch_fail3():
  208. """
  209. Test MixUpBatch op
  210. We expect this to fail because label column is not passed to mixup_batch
  211. """
  212. logger.info("test_mixup_batch_fail3")
  213. # Original Images
  214. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  215. ds_original = ds_original.batch(5, drop_remainder=True)
  216. images_original = None
  217. for idx, (image, _) in enumerate(ds_original):
  218. if idx == 0:
  219. images_original = image
  220. else:
  221. images_original = np.append(images_original, image, axis=0)
  222. # MixUp Images
  223. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  224. one_hot_op = data_trans.OneHot(num_classes=10)
  225. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  226. mixup_batch_op = vision.MixUpBatch()
  227. data1 = data1.batch(5, drop_remainder=True)
  228. data1 = data1.map(input_columns=["image"], operations=mixup_batch_op)
  229. with pytest.raises(RuntimeError) as error:
  230. images_mixup = np.array([])
  231. for idx, (image, _) in enumerate(data1):
  232. if idx == 0:
  233. images_mixup = image
  234. else:
  235. images_mixup = np.append(images_mixup, image, axis=0)
  236. error_message = "Both images and labels columns are required"
  237. assert error_message in str(error.value)
  238. def test_mixup_batch_fail4():
  239. """
  240. Test MixUpBatch Fail 2
  241. We expect this to fail because alpha is zero
  242. """
  243. logger.info("test_mixup_batch_fail4")
  244. # Original Images
  245. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  246. ds_original = ds_original.batch(5)
  247. images_original = np.array([])
  248. for idx, (image, _) in enumerate(ds_original):
  249. if idx == 0:
  250. images_original = image
  251. else:
  252. images_original = np.append(images_original, image, axis=0)
  253. # MixUp Images
  254. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  255. one_hot_op = data_trans.OneHot(num_classes=10)
  256. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  257. with pytest.raises(ValueError) as error:
  258. vision.MixUpBatch(0.0)
  259. error_message = "Input is not within the required interval"
  260. assert error_message in str(error.value)
  261. if __name__ == "__main__":
  262. test_mixup_batch_success1(plot=True)
  263. test_mixup_batch_success2(plot=True)
  264. test_mixup_batch_success3(plot=True)
  265. test_mixup_batch_md5()
  266. test_mixup_batch_fail1()
  267. test_mixup_batch_fail2()
  268. test_mixup_batch_fail3()
  269. test_mixup_batch_fail4()