You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_sbu.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test SBU dataset operators
  17. """
  18. import os
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. import pytest
  22. from PIL import Image
  23. import mindspore.dataset as ds
  24. import mindspore.dataset.vision.c_transforms as vision
  25. from mindspore import log as logger
  26. DATA_DIR = "../data/dataset/testSBUDataset"
  27. WRONG_DIR = "../data/dataset/testMnistData"
  28. def load_sbu(path):
  29. """
  30. load SBU data
  31. """
  32. images = []
  33. captions = []
  34. file1 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_urls.txt'))
  35. file2 = os.path.realpath(os.path.join(path, 'SBU_captioned_photo_dataset_captions.txt'))
  36. for line1, line2 in zip(open(file1), open(file2)):
  37. url = line1.rstrip()
  38. image = url[23:].replace("/", "_")
  39. filename = os.path.join(path, 'sbu_images', image)
  40. if os.path.exists(filename):
  41. caption = line2.rstrip()
  42. images.append(np.asarray(Image.open(filename).convert('RGB')).astype(np.uint8))
  43. captions.append(caption)
  44. return images, captions
  45. def visualize_dataset(images, captions):
  46. """
  47. Helper function to visualize the dataset samples
  48. """
  49. num_samples = len(images)
  50. for i in range(num_samples):
  51. plt.subplot(1, num_samples, i + 1)
  52. plt.imshow(images[i].squeeze())
  53. plt.title(captions[i])
  54. plt.show()
  55. def test_sbu_content_check():
  56. """
  57. Validate SBUDataset image readings
  58. """
  59. logger.info("Test SBUDataset Op with content check")
  60. dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=50, shuffle=False)
  61. images, captions = load_sbu(DATA_DIR)
  62. num_iter = 0
  63. # in this example, each dictionary has keys "image" and "caption"
  64. for i, data in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
  65. assert data["image"].shape == images[i].shape
  66. assert data["caption"].item().decode("utf8") == captions[i]
  67. num_iter += 1
  68. assert num_iter == 5
  69. def test_sbu_case():
  70. """
  71. Validate SBUDataset cases
  72. """
  73. dataset = ds.SBUDataset(DATA_DIR, decode=True)
  74. dataset = dataset.map(operations=[vision.Resize((224, 224))], input_columns=["image"])
  75. repeat_num = 4
  76. dataset = dataset.repeat(repeat_num)
  77. batch_size = 2
  78. dataset = dataset.batch(batch_size, drop_remainder=True, pad_info={})
  79. num = 0
  80. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  81. num += 1
  82. # 4 x 5 / 2
  83. assert num == 10
  84. dataset = ds.SBUDataset(DATA_DIR, decode=False)
  85. dataset = dataset.map(operations=[vision.Decode(rgb=True), vision.Resize((224, 224))], input_columns=["image"])
  86. repeat_num = 4
  87. dataset = dataset.repeat(repeat_num)
  88. batch_size = 2
  89. dataset = dataset.batch(batch_size, drop_remainder=True, pad_info={})
  90. num = 0
  91. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  92. num += 1
  93. # 4 x 5 / 2
  94. assert num == 10
  95. def test_sbu_basic():
  96. """
  97. Validate SBUDataset
  98. """
  99. logger.info("Test SBUDataset Op")
  100. # case 1: test loading whole dataset
  101. dataset = ds.SBUDataset(DATA_DIR, decode=True)
  102. num_iter = 0
  103. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  104. num_iter += 1
  105. assert num_iter == 5
  106. # case 2: test num_samples
  107. dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
  108. num_iter = 0
  109. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  110. num_iter += 1
  111. assert num_iter == 5
  112. # case 3: test repeat
  113. dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
  114. dataset = dataset.repeat(5)
  115. num_iter = 0
  116. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  117. num_iter += 1
  118. assert num_iter == 25
  119. # case 4: test batch
  120. dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
  121. assert dataset.get_dataset_size() == 5
  122. assert dataset.get_batch_size() == 1
  123. num_iter = 0
  124. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  125. num_iter += 1
  126. assert num_iter == 5
  127. # case 5: test get_class_indexing
  128. dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
  129. assert dataset.get_class_indexing() == {}
  130. # case 6: test get_col_names
  131. dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=5)
  132. assert dataset.get_col_names() == ["image", "caption"]
  133. def test_sbu_sequential_sampler():
  134. """
  135. Test SBUDataset with SequentialSampler
  136. """
  137. logger.info("Test SBUDataset Op with SequentialSampler")
  138. num_samples = 5
  139. sampler = ds.SequentialSampler(num_samples=num_samples)
  140. dataset_1 = ds.SBUDataset(DATA_DIR, decode=True, sampler=sampler)
  141. dataset_2 = ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_samples=num_samples)
  142. num_iter = 0
  143. for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  144. dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  145. np.testing.assert_array_equal(item1["caption"], item2["caption"])
  146. num_iter += 1
  147. assert num_iter == num_samples
  148. def test_sbu_exception():
  149. """
  150. Test error cases for SBUDataset
  151. """
  152. logger.info("Test error cases for SBUDataset")
  153. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  154. with pytest.raises(RuntimeError, match=error_msg_1):
  155. ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, sampler=ds.SequentialSampler())
  156. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  157. with pytest.raises(RuntimeError, match=error_msg_2):
  158. ds.SBUDataset(DATA_DIR, decode=True, sampler=ds.SequentialSampler(), num_shards=2, shard_id=0)
  159. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  160. with pytest.raises(RuntimeError, match=error_msg_3):
  161. ds.SBUDataset(DATA_DIR, decode=True, num_shards=10)
  162. error_msg_4 = "shard_id is specified but num_shards is not"
  163. with pytest.raises(RuntimeError, match=error_msg_4):
  164. ds.SBUDataset(DATA_DIR, decode=True, shard_id=0)
  165. error_msg_5 = "Input shard_id is not within the required interval"
  166. with pytest.raises(ValueError, match=error_msg_5):
  167. ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1)
  168. with pytest.raises(ValueError, match=error_msg_5):
  169. ds.SBUDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5)
  170. with pytest.raises(ValueError, match=error_msg_5):
  171. ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id=5)
  172. error_msg_6 = "num_parallel_workers exceeds"
  173. with pytest.raises(ValueError, match=error_msg_6):
  174. ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0)
  175. with pytest.raises(ValueError, match=error_msg_6):
  176. ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256)
  177. with pytest.raises(ValueError, match=error_msg_6):
  178. ds.SBUDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=-2)
  179. error_msg_7 = "Argument shard_id"
  180. with pytest.raises(TypeError, match=error_msg_7):
  181. ds.SBUDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0")
  182. def exception_func(item):
  183. raise Exception("Error occur!")
  184. error_msg_8 = "The corresponding data files"
  185. with pytest.raises(RuntimeError, match=error_msg_8):
  186. dataset = ds.SBUDataset(DATA_DIR, decode=True)
  187. dataset = dataset.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  188. for _ in dataset.__iter__():
  189. pass
  190. with pytest.raises(RuntimeError, match=error_msg_8):
  191. dataset = ds.SBUDataset(DATA_DIR, decode=True)
  192. dataset = dataset.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  193. for _ in dataset.__iter__():
  194. pass
  195. error_msg_9 = "does not exist or permission denied"
  196. with pytest.raises(ValueError, match=error_msg_9):
  197. dataset = ds.SBUDataset(WRONG_DIR, decode=True)
  198. for _ in dataset.__iter__():
  199. pass
  200. error_msg_10 = "Argument decode with value"
  201. with pytest.raises(TypeError, match=error_msg_10):
  202. dataset = ds.SBUDataset(DATA_DIR, decode="not_bool")
  203. for _ in dataset.__iter__():
  204. pass
  205. def test_sbu_visualize(plot=False):
  206. """
  207. Visualize SBUDataset results
  208. """
  209. logger.info("Test SBUDataset visualization")
  210. dataset = ds.SBUDataset(DATA_DIR, decode=True, num_samples=10, shuffle=False)
  211. num_iter = 0
  212. image_list, caption_list = [], []
  213. for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  214. image = item["image"]
  215. caption = item["caption"].item().decode("utf8")
  216. image_list.append(image)
  217. caption_list.append("caption {}".format(caption))
  218. assert isinstance(image, np.ndarray)
  219. assert image.dtype == np.uint8
  220. assert isinstance(caption, str)
  221. num_iter += 1
  222. assert num_iter == 5
  223. if plot:
  224. visualize_dataset(image_list, caption_list)
  225. def test_sbu_decode():
  226. """
  227. Validate SBUDataset image readings
  228. """
  229. logger.info("Test SBUDataset decode flag")
  230. sampler = ds.SequentialSampler(num_samples=50)
  231. dataset = ds.SBUDataset(dataset_dir=DATA_DIR, decode=False, sampler=sampler)
  232. dataset_1 = dataset.map(operations=[vision.Decode(rgb=True)], input_columns=["image"])
  233. dataset_2 = ds.SBUDataset(dataset_dir=DATA_DIR, decode=True, sampler=sampler)
  234. num_iter = 0
  235. for item1, item2 in zip(dataset_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  236. dataset_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  237. np.testing.assert_array_equal(item1["caption"], item2["caption"])
  238. num_iter += 1
  239. assert num_iter == 5
  240. if __name__ == '__main__':
  241. test_sbu_content_check()
  242. test_sbu_basic()
  243. test_sbu_case()
  244. test_sbu_sequential_sampler()
  245. test_sbu_exception()
  246. test_sbu_visualize(plot=True)
  247. test_sbu_decode()