You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cutmix_batch_op.py 12 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the CutMixBatch op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.vision.c_transforms as vision
  22. import mindspore.dataset.transforms.c_transforms as data_trans
  23. import mindspore.dataset.transforms.vision.utils as mode
  24. from mindspore import log as logger
  25. from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
  26. config_get_set_num_parallel_workers
# Relative path to the small CIFAR-10 test dataset used by every test below;
# resolved against the current working directory when the tests run.
DATA_DIR = "../data/dataset/testCifar10Data"
# Passed as generate_golden= to save_and_check_md5 in the MD5 tests; presumably
# True regenerates the golden .npz files instead of checking them (see util) —
# keep False in committed code.
GENERATE_GOLDEN = False
  29. def test_cutmix_batch_success1(plot=False):
  30. """
  31. Test CutMixBatch op with specified alpha and prob parameters on a batch of CHW images
  32. """
  33. logger.info("test_cutmix_batch_success1")
  34. # Original Images
  35. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  36. ds_original = ds_original.batch(5, drop_remainder=True)
  37. images_original = None
  38. for idx, (image, _) in enumerate(ds_original):
  39. if idx == 0:
  40. images_original = image
  41. else:
  42. images_original = np.append(images_original, image, axis=0)
  43. # CutMix Images
  44. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  45. hwc2chw_op = vision.HWC2CHW()
  46. data1 = data1.map(input_columns=["image"], operations=hwc2chw_op)
  47. one_hot_op = data_trans.OneHot(num_classes=10)
  48. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  49. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW, 2.0, 0.5)
  50. data1 = data1.batch(5, drop_remainder=True)
  51. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  52. images_cutmix = None
  53. for idx, (image, _) in enumerate(data1):
  54. if idx == 0:
  55. images_cutmix = image.transpose(0, 2, 3, 1)
  56. else:
  57. images_cutmix = np.append(images_cutmix, image.transpose(0, 2, 3, 1), axis=0)
  58. if plot:
  59. visualize_list(images_original, images_cutmix)
  60. num_samples = images_original.shape[0]
  61. mse = np.zeros(num_samples)
  62. for i in range(num_samples):
  63. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  64. logger.info("MSE= {}".format(str(np.mean(mse))))
  65. def test_cutmix_batch_success2(plot=False):
  66. """
  67. Test CutMixBatch op with default values for alpha and prob on a batch of HWC images
  68. """
  69. logger.info("test_cutmix_batch_success2")
  70. # Original Images
  71. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  72. ds_original = ds_original.batch(5, drop_remainder=True)
  73. images_original = None
  74. for idx, (image, _) in enumerate(ds_original):
  75. if idx == 0:
  76. images_original = image
  77. else:
  78. images_original = np.append(images_original, image, axis=0)
  79. # CutMix Images
  80. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  81. one_hot_op = data_trans.OneHot(num_classes=10)
  82. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  83. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  84. data1 = data1.batch(5, drop_remainder=True)
  85. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  86. images_cutmix = None
  87. for idx, (image, _) in enumerate(data1):
  88. if idx == 0:
  89. images_cutmix = image
  90. else:
  91. images_cutmix = np.append(images_cutmix, image, axis=0)
  92. if plot:
  93. visualize_list(images_original, images_cutmix)
  94. num_samples = images_original.shape[0]
  95. mse = np.zeros(num_samples)
  96. for i in range(num_samples):
  97. mse[i] = diff_mse(images_cutmix[i], images_original[i])
  98. logger.info("MSE= {}".format(str(np.mean(mse))))
  99. def test_cutmix_batch_nhwc_md5():
  100. """
  101. Test CutMixBatch on a batch of HWC images with MD5:
  102. """
  103. logger.info("test_cutmix_batch_nhwc_md5")
  104. original_seed = config_get_set_seed(0)
  105. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  106. # CutMixBatch Images
  107. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  108. one_hot_op = data_trans.OneHot(num_classes=10)
  109. data = data.map(input_columns=["label"], operations=one_hot_op)
  110. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  111. data = data.batch(5, drop_remainder=True)
  112. data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  113. filename = "cutmix_batch_c_nhwc_result.npz"
  114. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  115. # Restore config setting
  116. ds.config.set_seed(original_seed)
  117. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  118. def test_cutmix_batch_nchw_md5():
  119. """
  120. Test CutMixBatch on a batch of CHW images with MD5:
  121. """
  122. logger.info("test_cutmix_batch_nchw_md5")
  123. original_seed = config_get_set_seed(0)
  124. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  125. # CutMixBatch Images
  126. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  127. hwc2chw_op = vision.HWC2CHW()
  128. data = data.map(input_columns=["image"], operations=hwc2chw_op)
  129. one_hot_op = data_trans.OneHot(num_classes=10)
  130. data = data.map(input_columns=["label"], operations=one_hot_op)
  131. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
  132. data = data.batch(5, drop_remainder=True)
  133. data = data.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  134. filename = "cutmix_batch_c_nchw_result.npz"
  135. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  136. # Restore config setting
  137. ds.config.set_seed(original_seed)
  138. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  139. def test_cutmix_batch_fail1():
  140. """
  141. Test CutMixBatch Fail 1
  142. We expect this to fail because the images and labels are not batched
  143. """
  144. logger.info("test_cutmix_batch_fail1")
  145. # CutMixBatch Images
  146. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  147. one_hot_op = data_trans.OneHot(num_classes=10)
  148. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  149. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  150. with pytest.raises(RuntimeError) as error:
  151. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  152. for idx, (image, _) in enumerate(data1):
  153. if idx == 0:
  154. images_cutmix = image
  155. else:
  156. images_cutmix = np.append(images_cutmix, image, axis=0)
  157. error_message = "You must batch before calling CutMixBatch"
  158. assert error_message in str(error.value)
  159. def test_cutmix_batch_fail2():
  160. """
  161. Test CutMixBatch Fail 2
  162. We expect this to fail because alpha is negative
  163. """
  164. logger.info("test_cutmix_batch_fail2")
  165. # CutMixBatch Images
  166. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  167. one_hot_op = data_trans.OneHot(num_classes=10)
  168. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  169. with pytest.raises(ValueError) as error:
  170. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, -1)
  171. error_message = "Input is not within the required interval"
  172. assert error_message in str(error.value)
  173. def test_cutmix_batch_fail3():
  174. """
  175. Test CutMixBatch Fail 2
  176. We expect this to fail because prob is larger than 1
  177. """
  178. logger.info("test_cutmix_batch_fail3")
  179. # CutMixBatch Images
  180. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  181. one_hot_op = data_trans.OneHot(num_classes=10)
  182. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  183. with pytest.raises(ValueError) as error:
  184. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, 2)
  185. error_message = "Input is not within the required interval"
  186. assert error_message in str(error.value)
  187. def test_cutmix_batch_fail4():
  188. """
  189. Test CutMixBatch Fail 2
  190. We expect this to fail because prob is negative
  191. """
  192. logger.info("test_cutmix_batch_fail4")
  193. # CutMixBatch Images
  194. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  195. one_hot_op = data_trans.OneHot(num_classes=10)
  196. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  197. with pytest.raises(ValueError) as error:
  198. vision.CutMixBatch(mode.ImageBatchFormat.NHWC, 1, -1)
  199. error_message = "Input is not within the required interval"
  200. assert error_message in str(error.value)
  201. def test_cutmix_batch_fail5():
  202. """
  203. Test CutMixBatch op
  204. We expect this to fail because label column is not passed to cutmix_batch
  205. """
  206. logger.info("test_cutmix_batch_fail5")
  207. # CutMixBatch Images
  208. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  209. one_hot_op = data_trans.OneHot(num_classes=10)
  210. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  211. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  212. data1 = data1.batch(5, drop_remainder=True)
  213. data1 = data1.map(input_columns=["image"], operations=cutmix_batch_op)
  214. with pytest.raises(RuntimeError) as error:
  215. images_cutmix = np.array([])
  216. for idx, (image, _) in enumerate(data1):
  217. if idx == 0:
  218. images_cutmix = image
  219. else:
  220. images_cutmix = np.append(images_cutmix, image, axis=0)
  221. error_message = "Both images and labels columns are required"
  222. assert error_message in str(error.value)
  223. def test_cutmix_batch_fail6():
  224. """
  225. Test CutMixBatch op
  226. We expect this to fail because image_batch_format passed to CutMixBatch doesn't match the format of the images
  227. """
  228. logger.info("test_cutmix_batch_fail6")
  229. # CutMixBatch Images
  230. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  231. one_hot_op = data_trans.OneHot(num_classes=10)
  232. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  233. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NCHW)
  234. data1 = data1.batch(5, drop_remainder=True)
  235. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  236. with pytest.raises(RuntimeError) as error:
  237. images_cutmix = np.array([])
  238. for idx, (image, _) in enumerate(data1):
  239. if idx == 0:
  240. images_cutmix = image
  241. else:
  242. images_cutmix = np.append(images_cutmix, image, axis=0)
  243. error_message = "CutMixBatch: Image doesn't match the given image format."
  244. assert error_message in str(error.value)
  245. def test_cutmix_batch_fail7():
  246. """
  247. Test CutMixBatch op
  248. We expect this to fail because labels are not in one-hot format
  249. """
  250. logger.info("test_cutmix_batch_fail7")
  251. # CutMixBatch Images
  252. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  253. cutmix_batch_op = vision.CutMixBatch(mode.ImageBatchFormat.NHWC)
  254. data1 = data1.batch(5, drop_remainder=True)
  255. data1 = data1.map(input_columns=["image", "label"], operations=cutmix_batch_op)
  256. with pytest.raises(RuntimeError) as error:
  257. images_cutmix = np.array([])
  258. for idx, (image, _) in enumerate(data1):
  259. if idx == 0:
  260. images_cutmix = image
  261. else:
  262. images_cutmix = np.append(images_cutmix, image, axis=0)
  263. error_message = "CutMixBatch: Label's must be in one-hot format and in a batch"
  264. assert error_message in str(error.value)
  265. if __name__ == "__main__":
  266. test_cutmix_batch_success1(plot=True)
  267. test_cutmix_batch_success2(plot=True)
  268. test_cutmix_batch_nchw_md5()
  269. test_cutmix_batch_nhwc_md5()
  270. test_cutmix_batch_fail1()
  271. test_cutmix_batch_fail2()
  272. test_cutmix_batch_fail3()
  273. test_cutmix_batch_fail4()
  274. test_cutmix_batch_fail5()
  275. test_cutmix_batch_fail6()
  276. test_cutmix_batch_fail7()