You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cutmix_batch_op.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the CutMixBatch op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.vision.c_transforms as vision
  22. import mindspore.dataset.transforms.c_transforms as data_trans
  23. import mindspore.dataset.transforms.vision.utils as mode
  24. from mindspore import log as logger
  25. from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
  26. config_get_set_num_parallel_workers
  27. DATA_DIR = "../data/dataset/testCifar10Data"
  28. DATA_DIR2 = "../data/dataset/testImageNetData2/train/"
  29. GENERATE_GOLDEN = False
  30. def test_cutmix_batch_success1(plot=False):
  31. """
  32. Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images
  33. """
  34. logger.info("test_cutmix_batch_success1")
  35. # Original Images
  36. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  37. ds_original = ds_original.batch(5, drop_remainder=True)
  38. images_original = None
  39. for idx, (image, _) in enumerate(ds_original):
  40. if idx == 0:
  41. images_original = image
  42. else:
  43. images_original = np.append(images_original, image, axis=0)
  44. # CutMix Images
  45. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  46. hwc2chw_op = vision.HWC2CHW()
  47. data1 = data1.map(input_columns=["image"], operations=hwc2chw_op)
  48. one_hot_op = data_trans.OneHot(num_classes=10)
  49. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  50. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
  51. data1 = data1.batch(5, drop_remainder=True)
  52. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  53. images_cutmix = None
  54. for idx, (image, _) in enumerate(data1):
  55. if idx == 0:
  56. images_cutmix = image.transpose(0, 2, 3, 1)
  57. else:
  58. images_cutmix = np.append(images_cutmix, image.transpose(0, 2, 3, 1), axis=0)
  59. if plot:
  60. visualize_list(images_original, images_cutmix)
  61. num_samples = images_original.shape[0]
  62. mse = np.zeros(num_samples)
  63. for i in range(num_samples):
  64. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  65. logger.info("MSE= {}".format(str(np.mean(mse))))
  66. def test_cutmix_batch_success2(plot=False):
  67. """
  68. Test CutMixBatch op with default values for alpha and prob on a batch of rescaled HWC images
  69. """
  70. logger.info("test_cutmix_batch_success2")
  71. # Original Images
  72. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  73. ds_original = ds_original.batch(5, drop_remainder=True)
  74. images_original = None
  75. for idx, (image, _) in enumerate(ds_original):
  76. if idx == 0:
  77. images_original = image
  78. else:
  79. images_original = np.append(images_original, image, axis=0)
  80. # CutMix Images
  81. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  82. one_hot_op = data_trans.OneHot(num_classes=10)
  83. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  84. rescale_op = vision.Rescale((1.0/255.0), 0.0)
  85. data1 = data1.map(input_columns=["image"], operations=rescale_op)
  86. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  87. data1 = data1.batch(5, drop_remainder=True)
  88. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  89. images_cutmix = None
  90. for idx, (image, _) in enumerate(data1):
  91. if idx == 0:
  92. images_cutmix = image
  93. else:
  94. images_cutmix = np.append(images_cutmix, image, axis=0)
  95. if plot:
  96. visualize_list(images_original, images_cutmix)
  97. num_samples = images_original.shape[0]
  98. mse = np.zeros(num_samples)
  99. for i in range(num_samples):
  100. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  101. logger.info("MSE= {}".format(str(np.mean(mse))))
  102. def test_cutmix_batch_success3(plot=False):
  103. """
  104. Test CutMixBatch op with default values for alpha and prob on a batch of HWC images on ImageFolderDatasetV2
  105. """
  106. logger.info("test_cutmix_batch_success3")
  107. ds_original = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
  108. decode_op = vision.Decode()
  109. ds_original = ds_original.map(input_columns=["image"], operations=[decode_op])
  110. ds_original = ds_original.batch(4, pad_info={}, drop_remainder=True)
  111. images_original = None
  112. for idx, (image, _) in enumerate(ds_original):
  113. if idx == 0:
  114. images_original = image
  115. else:
  116. images_original = np.append(images_original, image, axis=0)
  117. # CutMix Images
  118. data1 = ds.ImageFolderDatasetV2(dataset_dir=DATA_DIR2, shuffle=False)
  119. decode_op = vision.Decode()
  120. data1 = data1.map(input_columns=["image"], operations=[decode_op])
  121. one_hot_op = data_trans.OneHot(num_classes=10)
  122. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  123. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  124. data1 = data1.batch(4, pad_info={}, drop_remainder=True)
  125. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  126. images_cutmix = None
  127. for idx, (image, _) in enumerate(data1):
  128. if idx == 0:
  129. images_cutmix = image
  130. else:
  131. images_cutmix = np.append(images_cutmix, image, axis=0)
  132. if plot:
  133. visualize_list(images_original, images_cutmix)
  134. num_samples = images_original.shape[0]
  135. mse = np.zeros(num_samples)
  136. for i in range(num_samples):
  137. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  138. logger.info("MSE= {}".format(str(np.mean(mse))))
  139. def test_cutmix_batch_nhwc_md5():
  140. """
  141. Test CutMixBatch on a batch of HWC images with MD5:
  142. """
  143. logger.info("test_cutmix_batch_nhwc_md5")
  144. original_seed = config_get_set_seed(0)
  145. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  146. # CutMixBatch Images
  147. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  148. one_hot_op = data_trans.OneHot(num_classes=10)
  149. data = data.map(input_columns=["label"], operations=one_hot_op)
  150. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  151. data = data.batch(5, drop_remainder=True)
  152. data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  153. filename = "cutmix_batch_c_nhwc_result.npz"
  154. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  155. # Restore config setting
  156. ds.config.set_seed(original_seed)
  157. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  158. def test_cutmix_batch_nchw_md5():
  159. """
  160. Test CutMixBatch on a batch of CHW images with MD5:
  161. """
  162. logger.info("test_cutmix_batch_nchw_md5")
  163. original_seed = config_get_set_seed(0)
  164. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  165. # CutMixBatch Images
  166. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  167. hwc2chw_op = vision.HWC2CHW()
  168. data = data.map(input_columns=["image"], operations=hwc2chw_op)
  169. one_hot_op = data_trans.OneHot(num_classes=10)
  170. data = data.map(input_columns=["label"], operations=one_hot_op)
  171. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
  172. data = data.batch(5, drop_remainder=True)
  173. data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  174. filename = "cutmix_batch_c_nchw_result.npz"
  175. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  176. # Restore config setting
  177. ds.config.set_seed(original_seed)
  178. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  179. def test_cutmix_batch_fail1():
  180. """
  181. Test CutMixBatch Fail 1
  182. We expect this to fail because the images and labels are not batched
  183. """
  184. logger.info("test_cutmix_batch_fail1")
  185. # CutMixBatch Images
  186. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  187. one_hot_op = data_trans.OneHot(num_classes=10)
  188. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  189. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  190. with pytest.raises(RuntimeError) as error:
  191. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  192. for idx, (image, _) in enumerate(data1):
  193. if idx == 0:
  194. images_cutmix = image
  195. else:
  196. images_cutmix = np.append(images_cutmix, image, axis=0)
  197. error_message = "You must make sure images are HWC or CHW and batch "
  198. assert error_message in str(error.value)
  199. def test_cutmix_batch_fail2():
  200. """
  201. Test CutMixBatch Fail 2
  202. We expect this to fail because alpha is negative
  203. """
  204. logger.info("test_cutmix_batch_fail2")
  205. # CutMixBatch Images
  206. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  207. one_hot_op = data_trans.OneHot(num_classes=10)
  208. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  209. with pytest.raises(ValueError) as error:
  210. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
  211. error_message = "Input is not within the required interval"
  212. assert error_message in str(error.value)
  213. def test_cutmix_batch_fail3():
  214. """
  215. Test CutMixBatch Fail 2
  216. We expect this to fail because prob is larger than 1
  217. """
  218. logger.info("test_cutmix_batch_fail3")
  219. # CutMixBatch Images
  220. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  221. one_hot_op = data_trans.OneHot(num_classes=10)
  222. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  223. with pytest.raises(ValueError) as error:
  224. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
  225. error_message = "Input is not within the required interval"
  226. assert error_message in str(error.value)
  227. def test_cutmix_batch_fail4():
  228. """
  229. Test CutMixBatch Fail 2
  230. We expect this to fail because prob is negative
  231. """
  232. logger.info("test_cutmix_batch_fail4")
  233. # CutMixBatch Images
  234. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  235. one_hot_op = data_trans.OneHot(num_classes=10)
  236. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  237. with pytest.raises(ValueError) as error:
  238. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
  239. error_message = "Input is not within the required interval"
  240. assert error_message in str(error.value)
  241. def test_cutmix_batch_fail5():
  242. """
  243. Test CutMixBatch op
  244. We expect this to fail because label column is not passed to cutmix_batch
  245. """
  246. logger.info("test_cutmix_batch_fail5")
  247. # CutMixBatch Images
  248. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  249. one_hot_op = data_trans.OneHot(num_classes=10)
  250. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  251. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  252. data1 = data1.batch(5, drop_remainder=True)
  253. data1 = data1.map(input_columns=["image"], operations=cutmix_batch_op)
  254. with pytest.raises(RuntimeError) as error:
  255. images_cutmix = np.array([])
  256. for idx, (image, _) in enumerate(data1):
  257. if idx == 0:
  258. images_cutmix = image
  259. else:
  260. images_cutmix = np.append(images_cutmix, image, axis=0)
  261. error_message = "Both images and labels columns are required"
  262. assert error_message in str(error.value)
  263. def test_cutmix_batch_fail6():
  264. """
  265. Test CutMixBatch op
  266. We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
  267. """
  268. logger.info("test_cutmix_batch_fail6")
  269. # CutMixBatch Images
  270. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  271. one_hot_op = data_trans.OneHot(num_classes=10)
  272. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  273. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
  274. data1 = data1.batch(5, drop_remainder=True)
  275. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  276. with pytest.raises(RuntimeError) as error:
  277. images_cutmix = np.array([])
  278. for idx, (image, _) in enumerate(data1):
  279. if idx == 0:
  280. images_cutmix = image
  281. else:
  282. images_cutmix = np.append(images_cutmix, image, axis=0)
  283. error_message = "CutMixBatch: Image doesn't match the given image format."
  284. assert error_message in str(error.value)
  285. def test_cutmix_batch_fail7():
  286. """
  287. Test CutMixBatch op
  288. We expect this to fail because labels are not in one-hot format
  289. """
  290. logger.info("test_cutmix_batch_fail7")
  291. # CutMixBatch Images
  292. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  293. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  294. data1 = data1.batch(5, drop_remainder=True)
  295. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  296. with pytest.raises(RuntimeError) as error:
  297. images_cutmix = np.array([])
  298. for idx, (image, _) in enumerate(data1):
  299. if idx == 0:
  300. images_cutmix = image
  301. else:
  302. images_cutmix = np.append(images_cutmix, image, axis=0)
  303. error_message = "CutMixBatch: Label's must be in one-hot format and in a batch"
  304. assert error_message in str(error.value)
  305. def test_cutmix_batch_fail8():
  306. """
  307. Test CutMixBatch Fail 8
  308. We expect this to fail because alpha is zero
  309. """
  310. logger.info("test_cutmix_batch_fail8")
  311. # CutMixBatch Images
  312. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  313. one_hot_op = data_trans.OneHot(num_classes=10)
  314. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  315. with pytest.raises(ValueError) as error:
  316. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 0.0)
  317. error_message = "Input is not within the required interval"
  318. assert error_message in str(error.value)
  319. if __name__ == "__main__":
  320. test_cutmix_batch_success1(plot=True)
  321. test_cutmix_batch_success2(plot=True)
  322. test_cutmix_batch_success3(plot=True)
  323. test_cutmix_batch_nchw_md5()
  324. test_cutmix_batch_nhwc_md5()
  325. test_cutmix_batch_fail1()
  326. test_cutmix_batch_fail2()
  327. test_cutmix_batch_fail3()
  328. test_cutmix_batch_fail4()
  329. test_cutmix_batch_fail5()
  330. test_cutmix_batch_fail6()
  331. test_cutmix_batch_fail7()
  332. test_cutmix_batch_fail8()