You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_mixup_op.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the MixUpBatch op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.vision.c_transforms as vision
  22. import mindspore.dataset.transforms.c_transforms as data_trans
  23. from mindspore import log as logger
  24. from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
  25. config_get_set_num_parallel_workers
  # Dataset directory passed to ds.Cifar10Dataset by the Cifar10-based tests in this file.
  26. DATA_DIR = "../data/dataset/testCifar10Data"
  # Dataset directory passed to ds.ImageFolderDataset in test_mixup_batch_success2.
  27. DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
  # Dataset directory passed to ds.CelebADataset in test_mixup_batch_success4.
  28. DATA_DIR3 = "../data/dataset/testCelebAData/"
  # Forwarded as generate_golden= to save_and_check_md5 in test_mixup_batch_md5.
  29. GENERATE_GOLDEN = False
  def test_mixup_batch_success1(plot=False):
      """
      Run MixUpBatch with an explicit integer alpha (2) on batched Cifar10 data.

      The plain and the mixed images are each gathered into one array,
      optionally shown with visualize_list when `plot` is True, and the mean
      of the per-image diff_mse values is logged.
      """
      logger.info("test_mixup_batch_success1")

      # Reference pipeline: plain batches with no augmentation
      baseline = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False).batch(5, drop_remainder=True)
      reference_images = None
      for step, (batch_images, _) in enumerate(baseline):
          chunk = batch_images.asnumpy()
          reference_images = chunk if step == 0 else np.append(reference_images, chunk, axis=0)

      # MixUp pipeline: one-hot labels -> batch -> MixUpBatch
      mixed = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      mixed = mixed.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
      mixed = mixed.batch(5, drop_remainder=True)
      mixed = mixed.map(operations=vision.MixUpBatch(2), input_columns=["image", "label"])
      mixed_images = None
      for step, (batch_images, _) in enumerate(mixed):
          chunk = batch_images.asnumpy()
          mixed_images = chunk if step == 0 else np.append(mixed_images, chunk, axis=0)

      if plot:
          visualize_list(reference_images, mixed_images)

      # One diff_mse value per image, averaged for the log line
      image_count = reference_images.shape[0]
      per_image_mse = np.zeros(image_count)
      for pos in range(image_count):
          per_image_mse[pos] = diff_mse(mixed_images[pos], reference_images[pos])
      logger.info("MSE= {}".format(str(np.mean(per_image_mse))))
  def test_mixup_batch_success2(plot=False):
      """
      Run MixUpBatch with an explicit float alpha (2.0) on an ImageFolderDataset.

      Images are decoded and batched in groups of 4; the plain and mixed
      images are gathered, optionally visualized when `plot` is True, and the
      mean of the per-image diff_mse values is logged.
      """
      logger.info("test_mixup_batch_success2")

      # Reference pipeline: decode, then batch without any mixing
      baseline = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
      baseline = baseline.map(operations=[vision.Decode()], input_columns=["image"])
      baseline = baseline.batch(4, pad_info={}, drop_remainder=True)
      reference_images = None
      for step, (batch_images, _) in enumerate(baseline):
          chunk = batch_images.asnumpy()
          reference_images = chunk if step == 0 else np.append(reference_images, chunk, axis=0)

      # MixUp pipeline: decode -> one-hot labels -> batch -> MixUpBatch
      mixed = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
      mixed = mixed.map(operations=[vision.Decode()], input_columns=["image"])
      mixed = mixed.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
      mixed = mixed.batch(4, pad_info={}, drop_remainder=True)
      mixed = mixed.map(operations=vision.MixUpBatch(2.0), input_columns=["image", "label"])
      mixed_images = None
      for step, (batch_images, _) in enumerate(mixed):
          chunk = batch_images.asnumpy()
          mixed_images = chunk if step == 0 else np.append(mixed_images, chunk, axis=0)

      if plot:
          visualize_list(reference_images, mixed_images)

      # One diff_mse value per image, averaged for the log line
      image_count = reference_images.shape[0]
      per_image_mse = np.zeros(image_count)
      for pos in range(image_count):
          per_image_mse[pos] = diff_mse(mixed_images[pos], reference_images[pos])
      logger.info("MSE= {}".format(str(np.mean(per_image_mse))))
  def test_mixup_batch_success3(plot=False):
      """
      Run MixUpBatch without passing alpha, so the op's default is used.

      Uses batched Cifar10 data; the plain and mixed images are gathered,
      optionally visualized when `plot` is True, and the mean of the
      per-image diff_mse values is logged.
      """
      logger.info("test_mixup_batch_success3")

      # Reference pipeline: plain batches with no augmentation
      baseline = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False).batch(5, drop_remainder=True)
      reference_images = None
      for step, (batch_images, _) in enumerate(baseline):
          chunk = batch_images.asnumpy()
          reference_images = chunk if step == 0 else np.append(reference_images, chunk, axis=0)

      # MixUp pipeline: one-hot labels -> batch -> MixUpBatch (default alpha)
      mixed = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      mixed = mixed.map(operations=data_trans.OneHot(num_classes=10), input_columns=["label"])
      mixed = mixed.batch(5, drop_remainder=True)
      mixed = mixed.map(operations=vision.MixUpBatch(), input_columns=["image", "label"])
      mixed_images = np.array([])
      for step, (batch_images, _) in enumerate(mixed):
          chunk = batch_images.asnumpy()
          mixed_images = chunk if step == 0 else np.append(mixed_images, chunk, axis=0)

      if plot:
          visualize_list(reference_images, mixed_images)

      # One diff_mse value per image, averaged for the log line
      image_count = reference_images.shape[0]
      per_image_mse = np.zeros(image_count)
      for pos in range(image_count):
          per_image_mse[pos] = diff_mse(mixed_images[pos], reference_images[pos])
      logger.info("MSE= {}".format(str(np.mean(per_image_mse))))
  def test_mixup_batch_success4(plot=False):
      """
      Run MixUpBatch on CelebA data, where OneHot on "attr" returns a 2D vector.

      Alpha is not passed, so the op's default is used. The plain and mixed
      images are gathered, optionally visualized when `plot` is True, and the
      mean of the per-image diff_mse values is logged.
      """
      logger.info("test_mixup_batch_success4")

      # Reference pipeline: decode, then batch without any mixing
      baseline = ds.CelebADataset(DATA_DIR3, shuffle=False)
      baseline = baseline.map(operations=[vision.Decode()], input_columns=["image"])
      baseline = baseline.batch(2, drop_remainder=True)
      reference_images = None
      for step, (batch_images, _) in enumerate(baseline):
          chunk = batch_images.asnumpy()
          reference_images = chunk if step == 0 else np.append(reference_images, chunk, axis=0)

      # MixUp pipeline: decode -> one-hot attributes -> batch -> MixUpBatch
      mixed = ds.CelebADataset(DATA_DIR3, shuffle=False)
      mixed = mixed.map(operations=[vision.Decode()], input_columns=["image"])
      mixed = mixed.map(operations=data_trans.OneHot(num_classes=100), input_columns=["attr"])
      mixed = mixed.batch(2, drop_remainder=True)
      mixed = mixed.map(operations=vision.MixUpBatch(), input_columns=["image", "attr"])
      mixed_images = np.array([])
      for step, (batch_images, _) in enumerate(mixed):
          chunk = batch_images.asnumpy()
          mixed_images = chunk if step == 0 else np.append(mixed_images, chunk, axis=0)

      if plot:
          visualize_list(reference_images, mixed_images)

      # One diff_mse value per image, averaged for the log line
      image_count = reference_images.shape[0]
      per_image_mse = np.zeros(image_count)
      for pos in range(image_count):
          per_image_mse[pos] = diff_mse(mixed_images[pos], reference_images[pos])
      logger.info("MSE= {}".format(str(np.mean(per_image_mse))))
  def test_mixup_batch_md5():
      """
      Test MixUpBatch with MD5.

      Sets the dataset seed to 0 and the number of parallel workers to 1,
      builds a one-hot -> batch -> MixUpBatch pipeline on Cifar10 and hands it
      to save_and_check_md5 together with the golden file name. The previous
      seed and worker count are restored even when the check fails, so a
      failure here does not leak configuration into other tests.
      """
      logger.info("test_mixup_batch_md5")
      original_seed = config_get_set_seed(0)
      original_num_parallel_workers = config_get_set_num_parallel_workers(1)

      try:
          # MixUp Images
          data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

          one_hot_op = data_trans.OneHot(num_classes=10)
          data = data.map(operations=one_hot_op, input_columns=["label"])
          mixup_batch_op = vision.MixUpBatch()
          data = data.batch(5, drop_remainder=True)
          data = data.map(operations=mixup_batch_op, input_columns=["image", "label"])

          filename = "mixup_batch_c_result.npz"
          save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
      finally:
          # Restore config setting regardless of the outcome above
          ds.config.set_seed(original_seed)
          ds.config.set_num_parallel_workers(original_num_parallel_workers)
  def test_mixup_batch_fail1():
      """
      Test MixUpBatch Fail 1
      We expect this to fail because the images and labels are not batched

      Raises (expected): RuntimeError whose message mentions that images must
      be HWC or CHW and batched.
      """
      logger.info("test_mixup_batch_fail1")

      # MixUp Images (deliberately NOT batched before MixUpBatch)
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

      one_hot_op = data_trans.OneHot(num_classes=10)
      data1 = data1.map(operations=one_hot_op, input_columns=["label"])
      mixup_batch_op = vision.MixUpBatch(0.1)

      with pytest.raises(RuntimeError) as error:
          data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])
          # The pipeline has to be iterated for the op to execute
          for _ in data1:
              pass

      # Checked after the `with` block: statements placed after the raising
      # call inside the block would never execute.
      error_message = "You must make sure images are HWC or CHW and batched"
      assert error_message in str(error.value)
  def test_mixup_batch_fail2():
      """
      Test MixUpBatch Fail 2
      We expect this to fail because alpha is negative

      Raises (expected): ValueError from the MixUpBatch constructor, so no
      dataset pipeline is needed for this check.
      """
      logger.info("test_mixup_batch_fail2")

      with pytest.raises(ValueError) as error:
          vision.MixUpBatch(-1)

      # Checked after the `with` block so the assertion is actually executed
      error_message = "Input is not within the required interval"
      assert error_message in str(error.value)
  def test_mixup_batch_fail3():
      """
      Test MixUpBatch Fail 3
      We expect this to fail because label column is not passed to mixup_batch

      Raises (expected): RuntimeError whose message states that both images
      and labels columns are required.
      """
      logger.info("test_mixup_batch_fail3")

      # MixUp Images
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

      one_hot_op = data_trans.OneHot(num_classes=10)
      data1 = data1.map(operations=one_hot_op, input_columns=["label"])
      mixup_batch_op = vision.MixUpBatch()
      data1 = data1.batch(5, drop_remainder=True)
      # Only the image column is handed to MixUpBatch -- this is the fault under test
      data1 = data1.map(operations=mixup_batch_op, input_columns=["image"])

      with pytest.raises(RuntimeError) as error:
          # The pipeline has to be iterated for the op to execute
          for _ in data1:
              pass

      # Checked after the `with` block so the assertion is actually executed
      error_message = "Both images and labels columns are required"
      assert error_message in str(error.value)
  def test_mixup_batch_fail4():
      """
      Test MixUpBatch Fail 4
      We expect this to fail because alpha is zero

      Raises (expected): ValueError from the MixUpBatch constructor, so no
      dataset pipeline is needed for this check.
      """
      logger.info("test_mixup_batch_fail4")

      with pytest.raises(ValueError) as error:
          vision.MixUpBatch(0.0)

      # Checked after the `with` block so the assertion is actually executed
      error_message = "Input is not within the required interval"
      assert error_message in str(error.value)
  def test_mixup_batch_fail5():
      """
      Test MixUpBatch Fail 5
      We expect this to fail because labels are not OneHot encoded

      Raises (expected): RuntimeError whose message reports a wrong labels
      shape (labels must be NC or NLC).
      """
      logger.info("test_mixup_batch_fail5")

      # MixUp Images (no OneHot step on the labels -- this is the fault under test)
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)

      mixup_batch_op = vision.MixUpBatch()
      data1 = data1.batch(5, drop_remainder=True)
      data1 = data1.map(operations=mixup_batch_op, input_columns=["image", "label"])

      with pytest.raises(RuntimeError) as error:
          # The pipeline has to be iterated for the op to execute
          for _ in data1:
              pass

      # Checked after the `with` block so the assertion is actually executed
      error_message = "MixUpBatch: Wrong labels shape. The second column (labels) must have a shape of NC or NLC"
      assert error_message in str(error.value)
  if __name__ == "__main__":
      # Success cases are run with plotting enabled when executed as a script
      for plotted_case in (
              test_mixup_batch_success1,
              test_mixup_batch_success2,
              test_mixup_batch_success3,
              test_mixup_batch_success4,
      ):
          plotted_case(plot=True)

      # The MD5 check and the failure cases take no arguments
      for plain_case in (
              test_mixup_batch_md5,
              test_mixup_batch_fail1,
              test_mixup_batch_fail2,
              test_mixup_batch_fail3,
              test_mixup_batch_fail4,
              test_mixup_batch_fail5,
      ):
          plain_case()