You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_cutmix_batch_op.py 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the CutMixBatch op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.vision.c_transforms as vision
  22. import mindspore.dataset.transforms.c_transforms as data_trans
  23. import mindspore.dataset.vision.utils as mode
  24. from mindspore import log as logger
  25. from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
  26. config_get_set_num_parallel_workers
  # Relative locations of the test datasets used below:
  # Cifar10Dataset, ImageFolderDataset and CelebADataset respectively.
  DATA_DIR, DATA_DIR2, DATA_DIR3 = (
      "../data/dataset/testCifar10Data",
      "../data/dataset/testImageNetData2/train/",
      "../data/dataset/testCelebAData/",
  )
  # Forwarded as ``generate_golden`` to save_and_check_md5 in the MD5 tests;
  # presumably True regenerates the golden files instead of checking them (see util).
  GENERATE_GOLDEN = False
  def test_cutmix_batch_success1(plot=False):
      """
      Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images

      Args:
          plot (bool): if True, show the original and CutMix images side by side.
      """
      logger.info("test_cutmix_batch_success1")
      # Original Images
      ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      ds_original = ds_original.batch(5, drop_remainder=True)
      # Gather every batch, then concatenate once; np.append inside the loop
      # re-copied the whole accumulated array on each iteration.
      images_original = np.concatenate([image.asnumpy() for image, _ in ds_original], axis=0)
      # CutMix Images
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      hwc2chw_op = vision.HWC2CHW()
      data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])
      one_hot_op = data_trans.OneHot(num_classes=10)
      data1 = data1.map(operations=one_hot_op, input_columns=["label"])
      cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
      data1 = data1.batch(5, drop_remainder=True)
      data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
      # Transpose NCHW back to NHWC so each image lines up with its original.
      images_cutmix = np.concatenate(
          [image.asnumpy().transpose(0, 2, 3, 1) for image, _ in data1], axis=0)
      if plot:
          visualize_list(images_original, images_cutmix)
      # CutMix mixes pixels and labels but must not change the batch geometry.
      assert images_cutmix.shape == images_original.shape
      mse = np.array([diff_mse(cutmix, original)
                      for cutmix, original in zip(images_cutmix, images_original)],
                     dtype=np.float64)
      logger.info("MSE= {}".format(str(np.mean(mse))))
  def test_cutmix_batch_success2(plot=False):
      """
      Test CutMixBatch op with default values for alpha and prob on a batch of rescaled HWC images

      Args:
          plot (bool): if True, show the original and CutMix images side by side.
      """
      logger.info("test_cutmix_batch_success2")
      # Original Images
      ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      ds_original = ds_original.batch(5, drop_remainder=True)
      # Gather every batch, then concatenate once instead of np.append per batch.
      images_original = np.concatenate([image.asnumpy() for image, _ in ds_original], axis=0)
      # CutMix Images
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      one_hot_op = data_trans.OneHot(num_classes=10)
      data1 = data1.map(operations=one_hot_op, input_columns=["label"])
      rescale_op = vision.Rescale((1.0 / 255.0), 0.0)
      data1 = data1.map(operations=rescale_op, input_columns=["image"])
      cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
      data1 = data1.batch(5, drop_remainder=True)
      data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
      images_cutmix = np.concatenate([image.asnumpy() for image, _ in data1], axis=0)
      if plot:
          visualize_list(images_original, images_cutmix)
      # CutMix mixes pixels and labels but must not change the batch geometry.
      assert images_cutmix.shape == images_original.shape
      # NOTE(review): images_cutmix went through Rescale(1/255) while images_original
      # did not, so this MSE also reflects the rescale, not just CutMix -- confirm intent.
      mse = np.array([diff_mse(cutmix, original)
                      for cutmix, original in zip(images_cutmix, images_original)],
                     dtype=np.float64)
      logger.info("MSE= {}".format(str(np.mean(mse))))
  def test_cutmix_batch_success3(plot=False):
      """
      Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDataset

      Args:
          plot (bool): if True, show the original and CutMix images side by side.
      """
      logger.info("test_cutmix_batch_success3")
      # Original Images: decoded and resized so every image in a batch has the same shape.
      ds_original = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
      decode_op = vision.Decode()
      ds_original = ds_original.map(operations=[decode_op], input_columns=["image"])
      resize_op = vision.Resize([224, 224])
      ds_original = ds_original.map(operations=[resize_op], input_columns=["image"])
      ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
      # Gather every batch, then concatenate once instead of np.append per batch.
      images_original = np.concatenate([image.asnumpy() for image, _ in ds_original], axis=0)
      # CutMix Images
      data1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
      decode_op = vision.Decode()
      data1 = data1.map(operations=[decode_op], input_columns=["image"])
      resize_op = vision.Resize([224, 224])
      data1 = data1.map(operations=[resize_op], input_columns=["image"])
      one_hot_op = data_trans.OneHot(num_classes=10)
      data1 = data1.map(operations=one_hot_op, input_columns=["label"])
      cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
      data1 = data1.batch(4, pad_info={}, drop_remainder=True)
      data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
      images_cutmix = np.concatenate([image.asnumpy() for image, _ in data1], axis=0)
      if plot:
          visualize_list(images_original, images_cutmix)
      # CutMix mixes pixels and labels but must not change the batch geometry.
      assert images_cutmix.shape == images_original.shape
      mse = np.array([diff_mse(cutmix, original)
                      for cutmix, original in zip(images_cutmix, images_original)],
                     dtype=np.float64)
      logger.info("MSE= {}".format(str(np.mean(mse))))
  def test_cutmix_batch_success4(plot=False):
      """
      Test CutMixBatch on a dataset where OneHot returns a 2D array per sample
      (OneHot is applied to the CelebA "attr" column).

      Args:
          plot (bool): if True, show the original and CutMix images side by side.
      """
      logger.info("test_cutmix_batch_success4")
      # Original Images
      ds_original = ds.CelebADataset(DATA_DIR3, shuffle=False)
      decode_op = vision.Decode()
      ds_original = ds_original.map(operations=[decode_op], input_columns=["image"])
      resize_op = vision.Resize([224, 224])
      ds_original = ds_original.map(operations=[resize_op], input_columns=["image"])
      ds_original = ds_original.batch(2, drop_remainder=True)
      # Gather every batch, then concatenate once instead of np.append per batch.
      images_original = np.concatenate([image.asnumpy() for image, _ in ds_original], axis=0)
      # CutMix Images
      data1 = ds.CelebADataset(dataset_dir=DATA_DIR3, shuffle=False)
      decode_op = vision.Decode()
      data1 = data1.map(operations=[decode_op], input_columns=["image"])
      resize_op = vision.Resize([224, 224])
      data1 = data1.map(operations=[resize_op], input_columns=["image"])
      one_hot_op = data_trans.OneHot(num_classes=100)
      data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
      cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
      data1 = data1.batch(2, drop_remainder=True)
      data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "attr"])
      images_cutmix = np.concatenate([image.asnumpy() for image, _ in data1], axis=0)
      if plot:
          visualize_list(images_original, images_cutmix)
      # CutMix mixes pixels and labels but must not change the batch geometry.
      assert images_cutmix.shape == images_original.shape
      mse = np.array([diff_mse(cutmix, original)
                      for cutmix, original in zip(images_cutmix, images_original)],
                     dtype=np.float64)
      logger.info("MSE= {}".format(str(np.mean(mse))))
  def test_cutmix_batch_nhwc_md5():
      """
      Test CutMixBatch on a batch of HWC images with MD5.

      The global seed and parallel-worker count are pinned so the random CutMix
      output is reproducible for the golden comparison. Both are restored in a
      ``finally`` block, so a failed comparison cannot leak the modified global
      config into tests that run afterwards.
      """
      logger.info("test_cutmix_batch_nhwc_md5")
      original_seed = config_get_set_seed(0)
      original_num_parallel_workers = config_get_set_num_parallel_workers(1)
      try:
          # CutMixBatch Images
          data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
          one_hot_op = data_trans.OneHot(num_classes=10)
          data = data.map(operations=one_hot_op, input_columns=["label"])
          cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
          data = data.batch(5, drop_remainder=True)
          data = data.map(operations=cutmix_batch_op, input_columns=["image", "label"])
          filename = "cutmix_batch_c_nhwc_result.npz"
          save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
      finally:
          # Restore config setting even when the MD5 check fails
          ds.config.set_seed(original_seed)
          ds.config.set_num_parallel_workers(original_num_parallel_workers)
  def test_cutmix_batch_nchw_md5():
      """
      Test CutMixBatch on a batch of CHW images with MD5.

      The global seed and parallel-worker count are pinned so the random CutMix
      output is reproducible for the golden comparison. Both are restored in a
      ``finally`` block, so a failed comparison cannot leak the modified global
      config into tests that run afterwards.
      """
      logger.info("test_cutmix_batch_nchw_md5")
      original_seed = config_get_set_seed(0)
      original_num_parallel_workers = config_get_set_num_parallel_workers(1)
      try:
          # CutMixBatch Images
          data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
          hwc2chw_op = vision.HWC2CHW()
          data = data.map(operations=hwc2chw_op, input_columns=["image"])
          one_hot_op = data_trans.OneHot(num_classes=10)
          data = data.map(operations=one_hot_op, input_columns=["label"])
          cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
          data = data.batch(5, drop_remainder=True)
          data = data.map(operations=cutmix_batch_op, input_columns=["image", "label"])
          filename = "cutmix_batch_c_nchw_result.npz"
          save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
      finally:
          # Restore config setting even when the MD5 check fails
          ds.config.set_seed(original_seed)
          ds.config.set_num_parallel_workers(original_num_parallel_workers)
  def test_cutmix_batch_fail1():
      """
      Test CutMixBatch Fail 1
      We expect this to fail because the images and labels are not batched
      """
      logger.info("test_cutmix_batch_fail1")
      # CutMixBatch Images
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      one_hot_op = data_trans.OneHot(num_classes=10)
      data1 = data1.map(operations=one_hot_op, input_columns=["label"])
      cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
      # No batch() before the op: it receives single images instead of a batch.
      data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
      with pytest.raises(RuntimeError) as error:
          # The error surfaces while the pipeline is consumed; the rows themselves
          # are never used, so there is nothing to accumulate.
          for _ in data1:
              pass
      error_message = "You must make sure images are HWC or CHW and batch "
      assert error_message in str(error.value)
  def test_cutmix_batch_fail2():
      """
      Test CutMixBatch Fail 2
      We expect this to fail because alpha is negative
      """
      logger.info("test_cutmix_batch_fail2")
      # The parameters are rejected when the op is constructed, so no dataset
      # pipeline is needed to trigger the error.
      with pytest.raises(ValueError) as error:
          vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
      error_message = "Input is not within the required interval"
      assert error_message in str(error.value)
  def test_cutmix_batch_fail3():
      """
      Test CutMixBatch Fail 3
      We expect this to fail because prob is larger than 1
      """
      logger.info("test_cutmix_batch_fail3")
      # The parameters are rejected when the op is constructed, so no dataset
      # pipeline is needed to trigger the error.
      with pytest.raises(ValueError) as error:
          vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
      error_message = "Input is not within the required interval"
      assert error_message in str(error.value)
  def test_cutmix_batch_fail4():
      """
      Test CutMixBatch Fail 4
      We expect this to fail because prob is negative
      """
      logger.info("test_cutmix_batch_fail4")
      # The parameters are rejected when the op is constructed, so no dataset
      # pipeline is needed to trigger the error.
      with pytest.raises(ValueError) as error:
          vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
      error_message = "Input is not within the required interval"
      assert error_message in str(error.value)
  def test_cutmix_batch_fail5():
      """
      Test CutMixBatch op
      We expect this to fail because label column is not passed to cutmix_batch
      """
      logger.info("test_cutmix_batch_fail5")
      # CutMixBatch Images
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      one_hot_op = data_trans.OneHot(num_classes=10)
      data1 = data1.map(operations=one_hot_op, input_columns=["label"])
      cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
      data1 = data1.batch(5, drop_remainder=True)
      # Only the image column is handed to the op; the label column is omitted.
      data1 = data1.map(operations=cutmix_batch_op, input_columns=["image"])
      with pytest.raises(RuntimeError) as error:
          # The error surfaces while the pipeline is consumed; the rows themselves
          # are never used, so there is nothing to accumulate.
          for _ in data1:
              pass
      error_message = "both image and label columns are required"
      assert error_message in str(error.value)
  def test_cutmix_batch_fail6():
      """
      Test CutMixBatch op
      We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
      """
      logger.info("test_cutmix_batch_fail6")
      # CutMixBatch Images
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      one_hot_op = data_trans.OneHot(num_classes=10)
      data1 = data1.map(operations=one_hot_op, input_columns=["label"])
      # NCHW is requested, but no HWC2CHW is applied to the images.
      cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
      data1 = data1.batch(5, drop_remainder=True)
      data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
      with pytest.raises(RuntimeError) as error:
          # The error surfaces while the pipeline is consumed; the rows themselves
          # are never used, so there is nothing to accumulate.
          for _ in data1:
              pass
      error_message = "image doesn't match the NCHW format."
      assert error_message in str(error.value)
  def test_cutmix_batch_fail7():
      """
      Test CutMixBatch op
      We expect this to fail because labels are not in one-hot format
      """
      logger.info("test_cutmix_batch_fail7")
      # CutMixBatch Images (deliberately without a OneHot map on "label")
      data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
      cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
      data1 = data1.batch(5, drop_remainder=True)
      data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
      with pytest.raises(RuntimeError) as error:
          # The error surfaces while the pipeline is consumed; the rows themselves
          # are never used, so there is nothing to accumulate.
          for _ in data1:
              pass
      error_message = "wrong labels shape. The second column (labels) must have a shape of NC or NLC"
      assert error_message in str(error.value)
  def test_cutmix_batch_fail8():
      """
      Test CutMixBatch Fail 8
      We expect this to fail because alpha is zero
      """
      logger.info("test_cutmix_batch_fail8")
      # The parameters are rejected when the op is constructed, so no dataset
      # pipeline is needed to trigger the error.
      with pytest.raises(ValueError) as error:
          vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
      error_message = "Input is not within the required interval"
      assert error_message in str(error.value)
  if __name__ == "__main__":
      # Success tests accept a plot flag; show their images when run directly.
      for visual_test in (test_cutmix_batch_success1,
                          test_cutmix_batch_success2,
                          test_cutmix_batch_success3,
                          test_cutmix_batch_success4):
          visual_test(plot=True)
      # MD5 and expected-failure tests take no arguments; order matches the original.
      for plain_test in (test_cutmix_batch_nchw_md5,
                         test_cutmix_batch_nhwc_md5,
                         test_cutmix_batch_fail1,
                         test_cutmix_batch_fail2,
                         test_cutmix_batch_fail3,
                         test_cutmix_batch_fail4,
                         test_cutmix_batch_fail5,
                         test_cutmix_batch_fail6,
                         test_cutmix_batch_fail7,
                         test_cutmix_batch_fail8):
          plain_test()