You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_fake_image.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test FakeImage dataset operators
  17. """
  18. import matplotlib.pyplot as plt
  19. import numpy as np
  20. import pytest
  21. import mindspore.dataset as ds
  22. from mindspore import log as logger
# Shared FakeImageDataset parameters used by every test case in this file.
num_images = 50           # total number of fake images generated per dataset
image_size = (28, 28, 3)  # per-image shape passed to FakeImageDataset
num_classes = 10          # number of distinct label classes
base_seed = 0             # seed controlling the generated image content
  27. def visualize_dataset(images, labels):
  28. """
  29. Helper function to visualize the dataset samples
  30. """
  31. num_samples = len(images)
  32. for i in range(num_samples):
  33. plt.subplot(1, num_samples, i + 1)
  34. plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
  35. plt.title(labels[i])
  36. plt.show()
  37. def test_fake_image_basic():
  38. """
  39. Feature: FakeImage
  40. Description: test basic usage of FakeImage
  41. Expectation: the dataset is as expected
  42. """
  43. logger.info("Test FakeImageDataset Op")
  44. # case 1: test loading whole dataset
  45. train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed)
  46. num_iter1 = 0
  47. for _ in train_data.create_dict_iterator(num_epochs=1):
  48. num_iter1 += 1
  49. assert num_iter1 == num_images
  50. # case 2: test num_samples
  51. train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=4)
  52. num_iter2 = 0
  53. for _ in train_data.create_dict_iterator(num_epochs=1):
  54. num_iter2 += 1
  55. assert num_iter2 == 4
  56. # case 3: test repeat
  57. train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=4)
  58. train_data = train_data.repeat(5)
  59. num_iter3 = 0
  60. for _ in train_data.create_dict_iterator(num_epochs=1):
  61. num_iter3 += 1
  62. assert num_iter3 == 20
  63. # case 4: test batch with drop_remainder=False, get_dataset_size, get_batch_size, get_col_names
  64. train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=4)
  65. assert train_data.get_dataset_size() == 4
  66. assert train_data.get_batch_size() == 1
  67. assert train_data.get_col_names() == ['image', 'label']
  68. train_data = train_data.batch(batch_size=3) # drop_remainder is default to be False
  69. assert train_data.get_dataset_size() == 2
  70. assert train_data.get_batch_size() == 3
  71. num_iter4 = 0
  72. for _ in train_data.create_dict_iterator(num_epochs=1):
  73. num_iter4 += 1
  74. assert num_iter4 == 2
  75. # case 5: test batch with drop_remainder=True
  76. train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=4)
  77. assert train_data.get_dataset_size() == 4
  78. assert train_data.get_batch_size() == 1
  79. train_data = train_data.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped
  80. assert train_data.get_dataset_size() == 1
  81. assert train_data.get_batch_size() == 3
  82. num_iter5 = 0
  83. for _ in train_data.create_dict_iterator(num_epochs=1):
  84. num_iter5 += 1
  85. assert num_iter5 == 1
  86. def test_fake_image_pk_sampler():
  87. """
  88. Feature: FakeImage
  89. Description: test FakeImageDataset with PKSamplere
  90. Expectation: the results are as expected
  91. """
  92. logger.info("Test FakeImageDataset Op with PKSampler")
  93. golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
  94. #correlation with num_classes
  95. sampler = ds.PKSampler(3)
  96. train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, sampler=sampler)
  97. num_iter = 0
  98. label_list = []
  99. for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
  100. label_list.append(item["label"])
  101. num_iter += 1
  102. np.testing.assert_array_equal(golden, label_list)
  103. assert num_iter == 30
  104. def test_fake_image_sequential_sampler():
  105. """
  106. Feature: FakeImage
  107. Description: test FakeImageDataset with SequentialSampler
  108. Expectation: the results are as expected
  109. """
  110. logger.info("Test FakeImageDataset Op with SequentialSampler")
  111. num_samples = 50
  112. sampler = ds.SequentialSampler(num_samples=num_samples)
  113. train_data1 = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, sampler=sampler)
  114. train_data2 = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False,
  115. num_samples=num_samples)
  116. label_list1, label_list2 = [], []
  117. num_iter = 0
  118. for item1, item2 in zip(train_data1.create_dict_iterator(num_epochs=1),
  119. train_data2.create_dict_iterator(num_epochs=1)):
  120. label_list1.append(item1["label"].asnumpy())
  121. label_list2.append(item2["label"].asnumpy())
  122. num_iter += 1
  123. np.testing.assert_array_equal(label_list1, label_list2)
  124. assert num_iter == num_samples
  125. def test_fake_image_exception():
  126. """
  127. Feature: FakeImage
  128. Description: test error cases for FakeImageDataset
  129. Expectation: throw exception correctly
  130. """
  131. logger.info("Test error cases for FakeImageDataset")
  132. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  133. with pytest.raises(RuntimeError, match=error_msg_1):
  134. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False, sampler=ds.PKSampler(3))
  135. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  136. with pytest.raises(RuntimeError, match=error_msg_2):
  137. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, sampler=ds.PKSampler(3), num_shards=2,
  138. shard_id=0)
  139. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  140. with pytest.raises(RuntimeError, match=error_msg_3):
  141. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=10)
  142. error_msg_4 = "shard_id is specified but num_shards is not"
  143. with pytest.raises(RuntimeError, match=error_msg_4):
  144. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shard_id=0)
  145. error_msg_5 = "Input shard_id is not within the required interval"
  146. with pytest.raises(ValueError, match=error_msg_5):
  147. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=5, shard_id=-1)
  148. with pytest.raises(ValueError, match=error_msg_5):
  149. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=5, shard_id=5)
  150. with pytest.raises(ValueError, match=error_msg_5):
  151. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=2, shard_id=5)
  152. error_msg_6 = "num_parallel_workers exceeds"
  153. with pytest.raises(ValueError, match=error_msg_6):
  154. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False, num_parallel_workers=0)
  155. with pytest.raises(ValueError, match=error_msg_6):
  156. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False, num_parallel_workers=256)
  157. with pytest.raises(ValueError, match=error_msg_6):
  158. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False, num_parallel_workers=-2)
  159. error_msg_7 = "Argument shard_id"
  160. with pytest.raises(TypeError, match=error_msg_7):
  161. ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=2, shard_id="0")
  162. def test_fake_image_visualize(plot=False):
  163. """
  164. Feature: FakeImage
  165. Description: test FakeImageDataset visualized results
  166. Expectation: get correct dataset of FakeImage
  167. """
  168. logger.info("Test FakeImageDataset visualization")
  169. train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=10, shuffle=False)
  170. num_iter = 0
  171. image_list, label_list = [], []
  172. for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
  173. image = item["image"]
  174. label = item["label"]
  175. image_list.append(image)
  176. label_list.append("label {}".format(label))
  177. assert isinstance(image, np.ndarray)
  178. assert image.shape == (28, 28, 3)
  179. assert image.dtype == np.uint8
  180. assert label.dtype == np.uint32
  181. num_iter += 1
  182. assert num_iter == 10
  183. if plot:
  184. visualize_dataset(image_list, label_list)
  185. def test_fake_image_num_images():
  186. """
  187. Feature: FakeImage
  188. Description: test FakeImageDataset with num images
  189. Expectation: throw exception correctly or get correct dataset
  190. """
  191. logger.info("Test FakeImageDataset num_images flag")
  192. def test_config(test_num_images):
  193. try:
  194. data = ds.FakeImageDataset(test_num_images, image_size, num_classes, base_seed, shuffle=False)
  195. num_rows = 0
  196. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  197. num_rows += 1
  198. except (ValueError, TypeError, RuntimeError) as e:
  199. return str(e)
  200. return num_rows
  201. assert test_config(num_images) == num_images
  202. assert "Input num_images is not within the required interval of [1, 2147483647]." in test_config(-1)
  203. assert "is not of type [<class 'int'>], but got <class 'str'>." in test_config("10")
  204. def test_fake_image_image_size():
  205. """
  206. Feature: FakeImage
  207. Description: test FakeImageDataset with image size
  208. Expectation: throw exception correctly or get correct dataset
  209. """
  210. logger.info("Test FakeImageDataset image_size flag")
  211. def test_config(test_image_size):
  212. try:
  213. data = ds.FakeImageDataset(num_images, test_image_size, num_classes, base_seed, shuffle=False)
  214. num_rows = 0
  215. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  216. num_rows += 1
  217. except (ValueError, TypeError, RuntimeError) as e:
  218. return str(e)
  219. return num_rows
  220. assert test_config(image_size) == num_images
  221. assert "Argument image_size[0] with value -1 is not of type [<class 'int'>], but got <class 'str'>."\
  222. in test_config(("-1", 28, 3))
  223. assert "image_size should be a list or tuple of length 3, but got 2" in test_config((2, 2))
  224. assert "Input image_size[0] is not within the required interval of [1, 2147483647]." in test_config((-1, 28, 3))
  225. def test_fake_image_num_classes():
  226. """
  227. Feature: FakeImage
  228. Description: test FakeImageDataset with num classes
  229. Expectation: throw exception correctly or get correct dataset
  230. """
  231. logger.info("Test FakeImageDataset num_classes flag")
  232. def test_config(test_num_classes):
  233. try:
  234. data = ds.FakeImageDataset(num_images, image_size, test_num_classes, base_seed, shuffle=False)
  235. num_rows = 0
  236. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  237. num_rows += 1
  238. except (ValueError, TypeError, RuntimeError) as e:
  239. return str(e)
  240. return num_rows
  241. assert test_config(num_classes) == num_images
  242. assert "Input num_classes is not within the required interval of [1, 2147483647]." in test_config(-1)
  243. #should not be negative
  244. assert "is not of type [<class 'int'>], but got <class 'str'>." in test_config("10")
if __name__ == '__main__':
    # Allow running all FakeImageDataset test cases directly as a script
    # (outside of pytest); the visualize case renders plots when run this way.
    test_fake_image_basic()
    test_fake_image_pk_sampler()
    test_fake_image_sequential_sampler()
    test_fake_image_exception()
    test_fake_image_visualize(plot=True)
    test_fake_image_num_images()
    test_fake_image_image_size()
    test_fake_image_num_classes()