You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_cutmix_batch_op.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the CutMixBatch op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.vision.c_transforms as vision
  22. import mindspore.dataset.transforms.c_transforms as data_trans
  23. import mindspore.dataset.vision.utils as mode
  24. from mindspore import log as logger
  25. from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
  26. config_get_set_num_parallel_workers
  # Dataset directory read by the Cifar10Dataset-based tests in this file.
  27. DATA_DIR = "../data/dataset/testCifar10Data"
  # Dataset directory read by the ImageFolderDataset-based test.
  28. DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
  # Dataset directory read by the CelebADataset-based test.
  29. DATA_DIR3 = "../data/dataset/testCelebAData/"
  # Forwarded as generate_golden to save_and_check_md5 in the MD5 tests.
  30. GENERATE_GOLDEN = False
  31. def test_cutmix_batch_success1(plot=False):
  32. """
  33. Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images
  34. """
  35. logger.info("test_cutmix_batch_success1")
  36. # Original Images
  37. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  38. ds_original = ds_original.batch(5, drop_remainder=True)
  39. images_original = None
  40. for idx, (image, _) in enumerate(ds_original):
  41. if idx == 0:
  42. images_original = image.asnumpy()
  43. else:
  44. images_original = np.append(images_original, image.asnumpy(), axis=0)
  45. # CutMix Images
  46. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  47. hwc2chw_op = vision.HWC2CHW()
  48. data1 = data1.map(operations=hwc2chw_op, input_columns=["image"])
  49. one_hot_op = data_trans.OneHot(num_classes=10)
  50. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  51. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
  52. data1 = data1.batch(5, drop_remainder=True)
  53. data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
  54. images_cutmix = None
  55. for idx, (image, _) in enumerate(data1):
  56. if idx == 0:
  57. images_cutmix = image.asnumpy().transpose(0, 2, 3, 1)
  58. else:
  59. images_cutmix = np.append(images_cutmix, image.asnumpy().transpose(0, 2, 3, 1), axis=0)
  60. if plot:
  61. visualize_list(images_original, images_cutmix)
  62. num_samples = images_original.shape[0]
  63. mse = np.zeros(num_samples)
  64. for i in range(num_samples):
  65. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  66. logger.info("MSE= {}".format(str(np.mean(mse))))
  67. def test_cutmix_batch_success2(plot=False):
  68. """
  69. Test CutMixBatch op with default values for alpha and prob on a batch of rescaled HWC images
  70. """
  71. logger.info("test_cutmix_batch_success2")
  72. # Original Images
  73. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  74. ds_original = ds_original.batch(5, drop_remainder=True)
  75. images_original = None
  76. for idx, (image, _) in enumerate(ds_original):
  77. if idx == 0:
  78. images_original = image.asnumpy()
  79. else:
  80. images_original = np.append(images_original, image.asnumpy(), axis=0)
  81. # CutMix Images
  82. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  83. one_hot_op = data_trans.OneHot(num_classes=10)
  84. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  85. rescale_op = vision.Rescale((1.0 / 255.0), 0.0)
  86. data1 = data1.map(operations=rescale_op, input_columns=["image"])
  87. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  88. data1 = data1.batch(5, drop_remainder=True)
  89. data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
  90. images_cutmix = None
  91. for idx, (image, _) in enumerate(data1):
  92. if idx == 0:
  93. images_cutmix = image.asnumpy()
  94. else:
  95. images_cutmix = np.append(images_cutmix, image.asnumpy(), axis=0)
  96. if plot:
  97. visualize_list(images_original, images_cutmix)
  98. num_samples = images_original.shape[0]
  99. mse = np.zeros(num_samples)
  100. for i in range(num_samples):
  101. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  102. logger.info("MSE= {}".format(str(np.mean(mse))))
  103. def test_cutmix_batch_success3(plot=False):
  104. """
  105. Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDataset
  106. """
  107. logger.info("test_cutmix_batch_success3")
  108. ds_original = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
  109. decode_op = vision.Decode()
  110. ds_original = ds_original.map(operations=[decode_op], input_columns=["image"])
  111. ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
  112. images_original = None
  113. for idx, (image, _) in enumerate(ds_original):
  114. if idx == 0:
  115. images_original = image.asnumpy()
  116. else:
  117. images_original = np.append(images_original, image.asnumpy(), axis=0)
  118. # CutMix Images
  119. data1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR2, shuffle=False)
  120. decode_op = vision.Decode()
  121. data1 = data1.map(operations=[decode_op], input_columns=["image"])
  122. one_hot_op = data_trans.OneHot(num_classes=10)
  123. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  124. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  125. data1 = data1.batch(4, pad_info={}, drop_remainder=True)
  126. data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
  127. images_cutmix = None
  128. for idx, (image, _) in enumerate(data1):
  129. if idx == 0:
  130. images_cutmix = image.asnumpy()
  131. else:
  132. images_cutmix = np.append(images_cutmix, image.asnumpy(), axis=0)
  133. if plot:
  134. visualize_list(images_original, images_cutmix)
  135. num_samples = images_original.shape[0]
  136. mse = np.zeros(num_samples)
  137. for i in range(num_samples):
  138. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  139. logger.info("MSE= {}".format(str(np.mean(mse))))
  140. def test_cutmix_batch_success4(plot=False):
  141. """
  142. Test CutMixBatch on a dataset where OneHot returns a 2D vector
  143. """
  144. logger.info("test_cutmix_batch_success4")
  145. ds_original = ds.CelebADataset(DATA_DIR3, shuffle=False)
  146. decode_op = vision.Decode()
  147. ds_original = ds_original.map(operations=[decode_op], input_columns=["image"])
  148. ds_original = ds_original.batch(2, drop_remainder=True)
  149. images_original = None
  150. for idx, (image, _) in enumerate(ds_original):
  151. if idx == 0:
  152. images_original = image.asnumpy()
  153. else:
  154. images_original = np.append(images_original, image.asnumpy(), axis=0)
  155. # CutMix Images
  156. data1 = ds.CelebADataset(dataset_dir=DATA_DIR3, shuffle=False)
  157. decode_op = vision.Decode()
  158. data1 = data1.map(operations=[decode_op], input_columns=["image"])
  159. one_hot_op = data_trans.OneHot(num_classes=100)
  160. data1 = data1.map(operations=one_hot_op, input_columns=["attr"])
  161. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.5, 0.9)
  162. data1 = data1.batch(2, drop_remainder=True)
  163. data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "attr"])
  164. images_cutmix = None
  165. for idx, (image, _) in enumerate(data1):
  166. if idx == 0:
  167. images_cutmix = image.asnumpy()
  168. else:
  169. images_cutmix = np.append(images_cutmix, image.asnumpy(), axis=0)
  170. if plot:
  171. visualize_list(images_original, images_cutmix)
  172. num_samples = images_original.shape[0]
  173. mse = np.zeros(num_samples)
  174. for i in range(num_samples):
  175. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  176. logger.info("MSE= {}".format(str(np.mean(mse))))
  177. def test_cutmix_batch_nhwc_md5():
  178. """
  179. Test CutMixBatch on a batch of HWC images with MD5:
  180. """
  181. logger.info("test_cutmix_batch_nhwc_md5")
  182. original_seed = config_get_set_seed(0)
  183. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  184. # CutMixBatch Images
  185. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  186. one_hot_op = data_trans.OneHot(num_classes=10)
  187. data = data.map(operations=one_hot_op, input_columns=["label"])
  188. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  189. data = data.batch(5, drop_remainder=True)
  190. data = data.map(operations=cutmix_batch_op, input_columns=["image", "label"])
  191. filename = "cutmix_batch_c_nhwc_result.npz"
  192. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  193. # Restore config setting
  194. ds.config.set_seed(original_seed)
  195. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  196. def test_cutmix_batch_nchw_md5():
  197. """
  198. Test CutMixBatch on a batch of CHW images with MD5:
  199. """
  200. logger.info("test_cutmix_batch_nchw_md5")
  201. original_seed = config_get_set_seed(0)
  202. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  203. # CutMixBatch Images
  204. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  205. hwc2chw_op = vision.HWC2CHW()
  206. data = data.map(operations=hwc2chw_op, input_columns=["image"])
  207. one_hot_op = data_trans.OneHot(num_classes=10)
  208. data = data.map(operations=one_hot_op, input_columns=["label"])
  209. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
  210. data = data.batch(5, drop_remainder=True)
  211. data = data.map(operations=cutmix_batch_op, input_columns=["image", "label"])
  212. filename = "cutmix_batch_c_nchw_result.npz"
  213. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  214. # Restore config setting
  215. ds.config.set_seed(original_seed)
  216. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  217. def test_cutmix_batch_fail1():
  218. """
  219. Test CutMixBatch Fail 1
  220. We expect this to fail because the images and labels are not batched
  221. """
  222. logger.info("test_cutmix_batch_fail1")
  223. # CutMixBatch Images
  224. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  225. one_hot_op = data_trans.OneHot(num_classes=10)
  226. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  227. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  228. with pytest.raises(RuntimeError) as error:
  229. data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
  230. for idx, (image, _) in enumerate(data1):
  231. if idx == 0:
  232. images_cutmix = image.asnumpy()
  233. else:
  234. images_cutmix = np.append(images_cutmix, image.asnumpy(), axis=0)
  235. error_message = "You must make sure images are HWC or CHW and batch "
  236. assert error_message in str(error.value)
  237. def test_cutmix_batch_fail2():
  238. """
  239. Test CutMixBatch Fail 2
  240. We expect this to fail because alpha is negative
  241. """
  242. logger.info("test_cutmix_batch_fail2")
  243. # CutMixBatch Images
  244. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  245. one_hot_op = data_trans.OneHot(num_classes=10)
  246. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  247. with pytest.raises(ValueError) as error:
  248. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
  249. error_message = "Input is not within the required interval"
  250. assert error_message in str(error.value)
  251. def test_cutmix_batch_fail3():
  252. """
  253. Test CutMixBatch Fail 2
  254. We expect this to fail because prob is larger than 1
  255. """
  256. logger.info("test_cutmix_batch_fail3")
  257. # CutMixBatch Images
  258. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  259. one_hot_op = data_trans.OneHot(num_classes=10)
  260. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  261. with pytest.raises(ValueError) as error:
  262. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
  263. error_message = "Input is not within the required interval"
  264. assert error_message in str(error.value)
  265. def test_cutmix_batch_fail4():
  266. """
  267. Test CutMixBatch Fail 2
  268. We expect this to fail because prob is negative
  269. """
  270. logger.info("test_cutmix_batch_fail4")
  271. # CutMixBatch Images
  272. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  273. one_hot_op = data_trans.OneHot(num_classes=10)
  274. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  275. with pytest.raises(ValueError) as error:
  276. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
  277. error_message = "Input is not within the required interval"
  278. assert error_message in str(error.value)
  279. def test_cutmix_batch_fail5():
  280. """
  281. Test CutMixBatch op
  282. We expect this to fail because label column is not passed to cutmix_batch
  283. """
  284. logger.info("test_cutmix_batch_fail5")
  285. # CutMixBatch Images
  286. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  287. one_hot_op = data_trans.OneHot(num_classes=10)
  288. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  289. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  290. data1 = data1.batch(5, drop_remainder=True)
  291. data1 = data1.map(operations=cutmix_batch_op, input_columns=["image"])
  292. with pytest.raises(RuntimeError) as error:
  293. images_cutmix = np.array([])
  294. for idx, (image, _) in enumerate(data1):
  295. if idx == 0:
  296. images_cutmix = image.asnumpy()
  297. else:
  298. images_cutmix = np.append(images_cutmix, image.asnumpy(), axis=0)
  299. error_message = "Both images and labels columns are required"
  300. assert error_message in str(error.value)
  301. def test_cutmix_batch_fail6():
  302. """
  303. Test CutMixBatch op
  304. We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
  305. """
  306. logger.info("test_cutmix_batch_fail6")
  307. # CutMixBatch Images
  308. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  309. one_hot_op = data_trans.OneHot(num_classes=10)
  310. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  311. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
  312. data1 = data1.batch(5, drop_remainder=True)
  313. data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
  314. with pytest.raises(RuntimeError) as error:
  315. images_cutmix = np.array([])
  316. for idx, (image, _) in enumerate(data1):
  317. if idx == 0:
  318. images_cutmix = image.asnumpy()
  319. else:
  320. images_cutmix = np.append(images_cutmix, image.asnumpy(), axis=0)
  321. error_message = "CutMixBatch: Image doesn't match the given image format."
  322. assert error_message in str(error.value)
  323. def test_cutmix_batch_fail7():
  324. """
  325. Test CutMixBatch op
  326. We expect this to fail because labels are not in one-hot format
  327. """
  328. logger.info("test_cutmix_batch_fail7")
  329. # CutMixBatch Images
  330. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  331. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  332. data1 = data1.batch(5, drop_remainder=True)
  333. data1 = data1.map(operations=cutmix_batch_op, input_columns=["image", "label"])
  334. with pytest.raises(RuntimeError) as error:
  335. images_cutmix = np.array([])
  336. for idx, (image, _) in enumerate(data1):
  337. if idx == 0:
  338. images_cutmix = image.asnumpy()
  339. else:
  340. images_cutmix = np.append(images_cutmix, image.asnumpy(), axis=0)
  341. error_message = "CutMixBatch: Wrong labels shape. The second column (labels) must have a shape of NC or NLC"
  342. assert error_message in str(error.value)
  343. def test_cutmix_batch_fail8():
  344. """
  345. Test CutMixBatch Fail 8
  346. We expect this to fail because alpha is zero
  347. """
  348. logger.info("test_cutmix_batch_fail8")
  349. # CutMixBatch Images
  350. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  351. one_hot_op = data_trans.OneHot(num_classes=10)
  352. data1 = data1.map(operations=one_hot_op, input_columns=["label"])
  353. with pytest.raises(ValueError) as error:
  354. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
  355. error_message = "Input is not within the required interval"
  356. assert error_message in str(error.value)
  357. if __name__ == "__main__":
  358. test_cutmix_batch_success1(plot=True)
  359. test_cutmix_batch_success2(plot=True)
  360. test_cutmix_batch_success3(plot=True)
  361. test_cutmix_batch_success4(plot=True)
  362. test_cutmix_batch_nchw_md5()
  363. test_cutmix_batch_nhwc_md5()
  364. test_cutmix_batch_fail1()
  365. test_cutmix_batch_fail2()
  366. test_cutmix_batch_fail3()
  367. test_cutmix_batch_fail4()
  368. test_cutmix_batch_fail5()
  369. test_cutmix_batch_fail6()
  370. test_cutmix_batch_fail7()
  371. test_cutmix_batch_fail8()