You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_mixup_op.py 8.6 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing the MixUpBatch op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.transforms.vision.c_transforms as vision
  22. import mindspore.dataset.transforms.c_transforms as data_trans
  23. from mindspore import log as logger
  24. from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
  25. config_get_set_num_parallel_workers
# Location of the CIFAR-10 fixture handed to ds.Cifar10Dataset by the tests in
# this module. NOTE(review): the path is relative, so the tests presumably
# have to be launched from the directory that makes it resolve - confirm.
DATA_DIR = "../data/dataset/testCifar10Data"
# Forwarded as generate_golden= to save_and_check_md5 in test_mixup_batch_md5.
# NOTE(review): presumably True rewrites the golden file instead of comparing
# against it - confirm against util.save_and_check_md5.
GENERATE_GOLDEN = False
  28. def test_mixup_batch_success1(plot=False):
  29. """
  30. Test MixUpBatch op with specified alpha parameter
  31. """
  32. logger.info("test_mixup_batch_success1")
  33. # Original Images
  34. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  35. ds_original = ds_original.batch(5, drop_remainder=True)
  36. images_original = None
  37. for idx, (image, _) in enumerate(ds_original):
  38. if idx == 0:
  39. images_original = image
  40. else:
  41. images_original = np.append(images_original, image, axis=0)
  42. # MixUp Images
  43. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  44. one_hot_op = data_trans.OneHot(num_classes=10)
  45. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  46. mixup_batch_op = vision.MixUpBatch(2)
  47. data1 = data1.batch(5, drop_remainder=True)
  48. data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
  49. images_mixup = None
  50. for idx, (image, _) in enumerate(data1):
  51. if idx == 0:
  52. images_mixup = image
  53. else:
  54. images_mixup = np.append(images_mixup, image, axis=0)
  55. if plot:
  56. visualize_list(images_original, images_mixup)
  57. num_samples = images_original.shape[0]
  58. mse = np.zeros(num_samples)
  59. for i in range(num_samples):
  60. mse[i] = diff_mse(images_mixup[i], images_original[i])
  61. logger.info("MSE= {}".format(str(np.mean(mse))))
  62. def test_mixup_batch_success2(plot=False):
  63. """
  64. Test MixUpBatch op without specified alpha parameter.
  65. Alpha parameter will be selected by default in this case
  66. """
  67. logger.info("test_mixup_batch_success2")
  68. # Original Images
  69. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  70. ds_original = ds_original.batch(5, drop_remainder=True)
  71. images_original = None
  72. for idx, (image, _) in enumerate(ds_original):
  73. if idx == 0:
  74. images_original = image
  75. else:
  76. images_original = np.append(images_original, image, axis=0)
  77. # MixUp Images
  78. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  79. one_hot_op = data_trans.OneHot(num_classes=10)
  80. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  81. mixup_batch_op = vision.MixUpBatch()
  82. data1 = data1.batch(5, drop_remainder=True)
  83. data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
  84. images_mixup = np.array([])
  85. for idx, (image, _) in enumerate(data1):
  86. if idx == 0:
  87. images_mixup = image
  88. else:
  89. images_mixup = np.append(images_mixup, image, axis=0)
  90. if plot:
  91. visualize_list(images_original, images_mixup)
  92. num_samples = images_original.shape[0]
  93. mse = np.zeros(num_samples)
  94. for i in range(num_samples):
  95. mse[i] = diff_mse(images_mixup[i], images_original[i])
  96. logger.info("MSE= {}".format(str(np.mean(mse))))
  97. def test_mixup_batch_md5():
  98. """
  99. Test MixUpBatch with MD5:
  100. """
  101. logger.info("test_mixup_batch_md5")
  102. original_seed = config_get_set_seed(0)
  103. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  104. # MixUp Images
  105. data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  106. one_hot_op = data_trans.OneHot(num_classes=10)
  107. data = data.map(input_columns=["label"], operations=one_hot_op)
  108. mixup_batch_op = vision.MixUpBatch()
  109. data = data.batch(5, drop_remainder=True)
  110. data = data.map(input_columns=["image", "label"], operations=mixup_batch_op)
  111. filename = "mixup_batch_c_result.npz"
  112. save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
  113. # Restore config setting
  114. ds.config.set_seed(original_seed)
  115. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  116. def test_mixup_batch_fail1():
  117. """
  118. Test MixUpBatch Fail 1
  119. We expect this to fail because the images and labels are not batched
  120. """
  121. logger.info("test_mixup_batch_fail1")
  122. # Original Images
  123. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  124. ds_original = ds_original.batch(5)
  125. images_original = np.array([])
  126. for idx, (image, _) in enumerate(ds_original):
  127. if idx == 0:
  128. images_original = image
  129. else:
  130. images_original = np.append(images_original, image, axis=0)
  131. # MixUp Images
  132. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  133. one_hot_op = data_trans.OneHot(num_classes=10)
  134. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  135. mixup_batch_op = vision.MixUpBatch(0.1)
  136. with pytest.raises(RuntimeError) as error:
  137. data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
  138. for idx, (image, _) in enumerate(data1):
  139. if idx == 0:
  140. images_mixup = image
  141. else:
  142. images_mixup = np.append(images_mixup, image, axis=0)
  143. error_message = "You must batch before calling MixUp"
  144. assert error_message in str(error.value)
  145. def test_mixup_batch_fail2():
  146. """
  147. Test MixUpBatch Fail 2
  148. We expect this to fail because alpha is negative
  149. """
  150. logger.info("test_mixup_batch_fail2")
  151. # Original Images
  152. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  153. ds_original = ds_original.batch(5)
  154. images_original = np.array([])
  155. for idx, (image, _) in enumerate(ds_original):
  156. if idx == 0:
  157. images_original = image
  158. else:
  159. images_original = np.append(images_original, image, axis=0)
  160. # MixUp Images
  161. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  162. one_hot_op = data_trans.OneHot(num_classes=10)
  163. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  164. with pytest.raises(ValueError) as error:
  165. vision.MixUpBatch(-1)
  166. error_message = "Input is not within the required interval"
  167. assert error_message in str(error.value)
  168. def test_mixup_batch_fail3():
  169. """
  170. Test MixUpBatch op
  171. We expect this to fail because label column is not passed to mixup_batch
  172. """
  173. # Original Images
  174. ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  175. ds_original = ds_original.batch(5, drop_remainder=True)
  176. images_original = None
  177. for idx, (image, _) in enumerate(ds_original):
  178. if idx == 0:
  179. images_original = image
  180. else:
  181. images_original = np.append(images_original, image, axis=0)
  182. # MixUp Images
  183. data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
  184. one_hot_op = data_trans.OneHot(num_classes=10)
  185. data1 = data1.map(input_columns=["label"], operations=one_hot_op)
  186. mixup_batch_op = vision.MixUpBatch()
  187. data1 = data1.batch(5, drop_remainder=True)
  188. data1 = data1.map(input_columns=["image"], operations=mixup_batch_op)
  189. with pytest.raises(RuntimeError) as error:
  190. images_mixup = np.array([])
  191. for idx, (image, _) in enumerate(data1):
  192. if idx == 0:
  193. images_mixup = image
  194. else:
  195. images_mixup = np.append(images_mixup, image, axis=0)
  196. error_message = "Both images and labels columns are required"
  197. assert error_message in str(error.value)
  198. if __name__ == "__main__":
  199. test_mixup_batch_success1(plot=True)
  200. test_mixup_batch_success2(plot=True)
  201. test_mixup_batch_md5()
  202. test_mixup_batch_fail1()
  203. test_mixup_batch_fail2()
  204. test_mixup_batch_fail3()