You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cutmix_batch_op.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the CutMixBatch op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.vision.c_transforms as vision
  22. import mindspore.dataset.transforms.c_transforms as data_trans
  23. import mindspore.dataset.transforms.vision.utils as mode
  24. from mindspore import log as logger
  25. from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
  26. config_get_set_num_parallel_workers
  27. DATA_DIR = "../data/dataset/testCifar10Data"
  28. DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
  29. GENERATE_GOLDEN = False
  30. def test_cutmix_batch_success1(plot=False):
  31. """
  32. Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images
  33. """
  34. logger.info("test_cutmix_batch_success1")
  35. # Original Images
  36. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  37. ds_original = ds_original.batch(5, drop_remainder=True)
  38. images_original = None
  39. for idx, (image, _) in enumerate(ds_original):
  40. if idx == 0:
  41. images_original = image
  42. else:
  43. images_original = np.append(images_original, image, axis=0)
  44. # CutMix Images
  45. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  46. hwc2chw_op = vision.HWC2CHW()
  47. data1 = data1.map(input_columns=["image"], operations=hwc2chw_op)
  48. one_hot_op = data_trans.OneHot(num_classes=10)
  49. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  50. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
  51. data1 = data1.batch(5, drop_remainder=True)
  52. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  53. images_cutmix = None
  54. for idx, (image, _) in enumerate(data1):
  55. if idx == 0:
  56. images_cutmix = image.transpose(0, 2, 3, 1)
  57. else:
  58. images_cutmix = np.append(images_cutmix, image.transpose(0, 2, 3, 1), axis=0)
  59. if plot:
  60. visualize_list(images_original, images_cutmix)
  61. num_samples = images_original.shape[0]
  62. mse = np.zeros(num_samples)
  63. for i in range(num_samples):
  64. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  65. logger.info("MSE= {}".format(str(np.mean(mse))))
  66. def test_cutmix_batch_success2(plot=False):
  67. """
  68. Test CutMixBatch op with default values for alpha and prob on a batch of HWC images
  69. """
  70. logger.info("test_cutmix_batch_success2")
  71. # Original Images
  72. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  73. ds_original = ds_original.batch(5, drop_remainder=True)
  74. images_original = None
  75. for idx, (image, _) in enumerate(ds_original):
  76. if idx == 0:
  77. images_original = image
  78. else:
  79. images_original = np.append(images_original, image, axis=0)
  80. # CutMix Images
  81. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  82. one_hot_op = data_trans.OneHot(num_classes=10)
  83. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  84. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  85. data1 = data1.batch(5, drop_remainder=True)
  86. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  87. images_cutmix = None
  88. for idx, (image, _) in enumerate(data1):
  89. if idx == 0:
  90. images_cutmix = image
  91. else:
  92. images_cutmix = np.append(images_cutmix, image, axis=0)
  93. if plot:
  94. visualize_list(images_original, images_cutmix)
  95. num_samples = images_original.shape[0]
  96. mse = np.zeros(num_samples)
  97. for i in range(num_samples):
  98. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  99. logger.info("MSE= {}".format(str(np.mean(mse))))
  100. def test_cutmix_batch_success3(plot=False):
  101. """
  102. Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDatasetV2
  103. """
  104. logger.info("test_cutmix_batch_success3")
  105. ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
  106. decode_op = vision.Decode()
  107. ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
  108. ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
  109. images_original = None
  110. for idx, (image, _) in enumerate(ds_original):
  111. if idx == 0:
  112. images_original = image
  113. else:
  114. images_original = np.append(images_original, image, axis=0)
  115. # CutMix Images
  116. data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
  117. decode_op = vision.Decode()
  118. data1 = data1.map(input_columns=["image"], operations=[decode_op])
  119. one_hot_op = data_trans.OneHot(num_classes=10)
  120. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  121. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  122. data1 = data1.batch(4, pad_info={}, drop_remainder=True)
  123. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  124. images_cutmix = None
  125. for idx, (image, _) in enumerate(data1):
  126. if idx == 0:
  127. images_cutmix = image
  128. else:
  129. images_cutmix = np.append(images_cutmix, image, axis=0)
  130. if plot:
  131. visualize_list(images_original, images_cutmix)
  132. num_samples = images_original.shape[0]
  133. mse = np.zeros(num_samples)
  134. for i in range(num_samples):
  135. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  136. logger.info("MSE= {}".format(str(np.mean(mse))))
  137. def test_cutmix_batch_nhwc_md5():
  138. """
  139. Test CutMixBatch on a batch of HWC images with MD5:
  140. """
  141. logger.info("test_cutmix_batch_nhwc_md5")
  142. original_seed = config_get_set_seed(0)
  143. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  144. # CutMixBatch Images
  145. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  146. one_hot_op = data_trans.OneHot(num_classes=10)
  147. data = data.map(input_columns=["label"], operations=one_hot_op)
  148. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  149. data = data.batch(5, drop_remainder=True)
  150. data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  151. filename = "cutmix_batch_c_nhwc_result.npz"
  152. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  153. # Restore config setting
  154. ds.config.set_seed(original_seed)
  155. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  156. def test_cutmix_batch_nchw_md5():
  157. """
  158. Test CutMixBatch on a batch of CHW images with MD5:
  159. """
  160. logger.info("test_cutmix_batch_nchw_md5")
  161. original_seed = config_get_set_seed(0)
  162. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  163. # CutMixBatch Images
  164. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  165. hwc2chw_op = vision.HWC2CHW()
  166. data = data.map(input_columns=["image"], operations=hwc2chw_op)
  167. one_hot_op = data_trans.OneHot(num_classes=10)
  168. data = data.map(input_columns=["label"], operations=one_hot_op)
  169. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
  170. data = data.batch(5, drop_remainder=True)
  171. data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  172. filename = "cutmix_batch_c_nchw_result.npz"
  173. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  174. # Restore config setting
  175. ds.config.set_seed(original_seed)
  176. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  177. def test_cutmix_batch_fail1():
  178. """
  179. Test CutMixBatch Fail 1
  180. We expect this to fail because the images and labels are not batched
  181. """
  182. logger.info("test_cutmix_batch_fail1")
  183. # CutMixBatch Images
  184. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  185. one_hot_op = data_trans.OneHot(num_classes=10)
  186. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  187. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  188. with pytest.raises(RuntimeError) as error:
  189. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  190. for idx, (image, _) in enumerate(data1):
  191. if idx == 0:
  192. images_cutmix = image
  193. else:
  194. images_cutmix = np.append(images_cutmix, image, axis=0)
  195. error_message = "You must make sure images are HWC or CHW and batch "
  196. assert error_message in str(error.value)
  197. def test_cutmix_batch_fail2():
  198. """
  199. Test CutMixBatch Fail 2
  200. We expect this to fail because alpha is negative
  201. """
  202. logger.info("test_cutmix_batch_fail2")
  203. # CutMixBatch Images
  204. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  205. one_hot_op = data_trans.OneHot(num_classes=10)
  206. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  207. with pytest.raises(ValueError) as error:
  208. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
  209. error_message = "Input is not within the required interval"
  210. assert error_message in str(error.value)
  211. def test_cutmix_batch_fail3():
  212. """
  213. Test CutMixBatch Fail 2
  214. We expect this to fail because prob is larger than 1
  215. """
  216. logger.info("test_cutmix_batch_fail3")
  217. # CutMixBatch Images
  218. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  219. one_hot_op = data_trans.OneHot(num_classes=10)
  220. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  221. with pytest.raises(ValueError) as error:
  222. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
  223. error_message = "Input is not within the required interval"
  224. assert error_message in str(error.value)
  225. def test_cutmix_batch_fail4():
  226. """
  227. Test CutMixBatch Fail 2
  228. We expect this to fail because prob is negative
  229. """
  230. logger.info("test_cutmix_batch_fail4")
  231. # CutMixBatch Images
  232. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  233. one_hot_op = data_trans.OneHot(num_classes=10)
  234. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  235. with pytest.raises(ValueError) as error:
  236. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
  237. error_message = "Input is not within the required interval"
  238. assert error_message in str(error.value)
  239. def test_cutmix_batch_fail5():
  240. """
  241. Test CutMixBatch op
  242. We expect this to fail because label column is not passed to cutmix_batch
  243. """
  244. logger.info("test_cutmix_batch_fail5")
  245. # CutMixBatch Images
  246. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  247. one_hot_op = data_trans.OneHot(num_classes=10)
  248. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  249. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  250. data1 = data1.batch(5, drop_remainder=True)
  251. data1 = data1.map(input_columns=["image"], operations=cutmix_batch_op)
  252. with pytest.raises(RuntimeError) as error:
  253. images_cutmix = np.array([])
  254. for idx, (image, _) in enumerate(data1):
  255. if idx == 0:
  256. images_cutmix = image
  257. else:
  258. images_cutmix = np.append(images_cutmix, image, axis=0)
  259. error_message = "Both images and labels columns are required"
  260. assert error_message in str(error.value)
  261. def test_cutmix_batch_fail6():
  262. """
  263. Test CutMixBatch op
  264. We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
  265. """
  266. logger.info("test_cutmix_batch_fail6")
  267. # CutMixBatch Images
  268. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  269. one_hot_op = data_trans.OneHot(num_classes=10)
  270. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  271. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
  272. data1 = data1.batch(5, drop_remainder=True)
  273. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  274. with pytest.raises(RuntimeError) as error:
  275. images_cutmix = np.array([])
  276. for idx, (image, _) in enumerate(data1):
  277. if idx == 0:
  278. images_cutmix = image
  279. else:
  280. images_cutmix = np.append(images_cutmix, image, axis=0)
  281. error_message = "CutMixBatch: Image doesn't match the given image format."
  282. assert error_message in str(error.value)
  283. def test_cutmix_batch_fail7():
  284. """
  285. Test CutMixBatch op
  286. We expect this to fail because labels are not in one-hot format
  287. """
  288. logger.info("test_cutmix_batch_fail7")
  289. # CutMixBatch Images
  290. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  291. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  292. data1 = data1.batch(5, drop_remainder=True)
  293. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  294. with pytest.raises(RuntimeError) as error:
  295. images_cutmix = np.array([])
  296. for idx, (image, _) in enumerate(data1):
  297. if idx == 0:
  298. images_cutmix = image
  299. else:
  300. images_cutmix = np.append(images_cutmix, image, axis=0)
  301. error_message = "CutMixBatch: Label's must be in one-hot format and in a batch"
  302. assert error_message in str(error.value)
  303. def test_cutmix_batch_fail8():
  304. """
  305. Test CutMixBatch Fail 8
  306. We expect this to fail because alpha is zero
  307. """
  308. logger.info("test_cutmix_batch_fail8")
  309. # CutMixBatch Images
  310. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  311. one_hot_op = data_trans.OneHot(num_classes=10)
  312. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  313. with pytest.raises(ValueError) as error:
  314. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
  315. error_message = "Input is not within the required interval"
  316. assert error_message in str(error.value)
  317. if __name__ == "__main__":
  318. test_cutmix_batch_success1(plot=True)
  319. test_cutmix_batch_success2(plot=True)
  320. test_cutmix_batch_success3(plot=True)
  321. test_cutmix_batch_nchw_md5()
  322. test_cutmix_batch_nhwc_md5()
  323. test_cutmix_batch_fail1()
  324. test_cutmix_batch_fail2()
  325. test_cutmix_batch_fail3()
  326. test_cutmix_batch_fail4()
  327. test_cutmix_batch_fail5()
  328. test_cutmix_batch_fail6()
  329. test_cutmix_batch_fail7()
  330. test_cutmix_batch_fail8()