You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_stl10.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""
Test STL10 dataset operators
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger

# Directory with the small STL10 test fixtures (train/test/unlabeled *_X.bin / *_y.bin files).
DATA_DIR = "../data/dataset/testSTL10Data"
# A directory that exists but does NOT contain STL10 files; used for negative tests.
WRONG_DIR = "../data/dataset/testMnistData"
  27. def loadfile(path_to_data, path_to_labels=None):
  28. """
  29. Feature: loadfile.
  30. Description: parse stl10 file.
  31. Expectation: get image and label of stl10 dataset.
  32. """
  33. labels = None
  34. if path_to_labels:
  35. with open(os.path.realpath(path_to_labels), 'rb') as f:
  36. labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based
  37. with open(path_to_data, 'rb') as f:
  38. # read whole file in uint8 chunks
  39. everything = np.fromfile(f, dtype=np.uint8)
  40. images = np.reshape(everything, (-1, 3, 96, 96))
  41. images = np.transpose(images, (0, 1, 3, 2))
  42. return images, labels
  43. def load_stl10(path, usage):
  44. """
  45. Feature: load_stl10.
  46. Description: load stl10.
  47. Expectation: get data of stl10 dataset.
  48. """
  49. assert usage in ["train", "test", "unlabeled", "train+unlabeled", "all"]
  50. if usage == "train":
  51. image_path = os.path.join(path, "train_X.bin")
  52. label_path = os.path.join(path, "train_y.bin")
  53. images, labels = loadfile(image_path, label_path)
  54. elif usage == "train+unlabeled":
  55. image_path = os.path.join(path, "train_X.bin")
  56. label_path = os.path.join(path, "train_y.bin")
  57. images, labels = loadfile(image_path, label_path)
  58. image_path = os.path.join(path, "unlabeled_X.bin")
  59. unlabeled_image, _ = loadfile(image_path)
  60. images = np.concatenate((images, unlabeled_image))
  61. labels = np.concatenate((labels, np.asarray([-1] * unlabeled_image.shape[0])))
  62. elif usage == "unlabeled":
  63. image_path = os.path.join(path, "unlabeled_X.bin")
  64. images, _ = loadfile(image_path)
  65. labels = np.asarray([-1] * images.shape[0])
  66. elif usage == "test":
  67. image_path = os.path.join(path, "test_X.bin")
  68. label_path = os.path.join(path, "test_y.bin")
  69. images, labels = loadfile(image_path, label_path)
  70. elif usage == "all":
  71. image_path = os.path.join(path, "test_X.bin")
  72. label_path = os.path.join(path, "test_y.bin")
  73. images, labels = loadfile(image_path, label_path)
  74. image_path = os.path.join(path, "train_X.bin")
  75. label_path = os.path.join(path, "train_y.bin")
  76. train_image, train_label = loadfile(image_path, label_path)
  77. images = np.concatenate((images, train_image))
  78. labels = np.concatenate((labels, train_label))
  79. image_path = os.path.join(path, "unlabeled_X.bin")
  80. unlabeled_image, _ = loadfile(image_path)
  81. images = np.concatenate((images, unlabeled_image))
  82. labels = np.concatenate((labels, np.asarray([-1] * unlabeled_image.shape[0])))
  83. return images, labels
  84. def visualize_dataset(images, labels):
  85. """
  86. Feature: visualize_dataset.
  87. Description: visualize stl10 dataset.
  88. Expectation: plot images.
  89. """
  90. num_samples = len(images)
  91. for i in range(num_samples):
  92. plt.subplot(1, num_samples, i + 1)
  93. plt.imshow(np.transpose(images[i], (1, 2, 0)))
  94. plt.title(labels[i])
  95. plt.show()
  96. def test_stl10_content_check():
  97. """
  98. Feature: test_stl10_content_check.
  99. Description: validate STL10ataset image readings.
  100. Expectation: get correct number of data and correct content.
  101. """
  102. logger.info("Test STL10Dataset Op with content check")
  103. # 1. train data.
  104. data1 = ds.STL10Dataset(DATA_DIR, usage="train", num_samples=1, shuffle=False)
  105. images, labels = load_stl10(DATA_DIR, "train")
  106. num_iter = 0
  107. # in this example, each dictionary has keys "image" and "label".
  108. for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  109. np.testing.assert_array_equal(d["image"], np.transpose(images[i], (1, 2, 0)))
  110. np.testing.assert_array_equal(d["label"], labels[i])
  111. num_iter += 1
  112. assert num_iter == 1
  113. # 2. test data.
  114. data1 = ds.STL10Dataset(DATA_DIR, usage="test", num_samples=1, shuffle=False)
  115. images, labels = load_stl10(DATA_DIR, "test")
  116. num_iter = 0
  117. # in this example, each dictionary has keys "image" and "label".
  118. for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  119. np.testing.assert_array_equal(d["image"], np.transpose(images[i], (1, 2, 0)))
  120. np.testing.assert_array_equal(d["label"], labels[i])
  121. num_iter += 1
  122. assert num_iter == 1
  123. # 3. unlabeled data.
  124. data1 = ds.STL10Dataset(DATA_DIR, usage="unlabeled", num_samples=1, shuffle=False)
  125. images, labels = load_stl10(DATA_DIR, "unlabeled")
  126. num_iter = 0
  127. # in this example, each dictionary has keys "image" and "label".
  128. for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  129. np.testing.assert_array_equal(d["image"], np.transpose(images[i], (1, 2, 0)))
  130. np.testing.assert_array_equal(d["label"], labels[i])
  131. num_iter += 1
  132. assert num_iter == 1
  133. # 4. train+unlabeled data.
  134. data1 = ds.STL10Dataset(DATA_DIR, usage="train+unlabeled", num_samples=2, shuffle=False)
  135. images, labels = load_stl10(DATA_DIR, "train+unlabeled")
  136. num_iter = 0
  137. # in this example, each dictionary has keys "image" and "label".
  138. for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  139. np.testing.assert_array_equal(d["image"], np.transpose(images[i], (1, 2, 0)))
  140. np.testing.assert_array_equal(d["label"], labels[i])
  141. num_iter += 1
  142. assert num_iter == 2
  143. # 4. all data.
  144. data1 = ds.STL10Dataset(DATA_DIR, usage="all", num_samples=3, shuffle=False)
  145. images, labels = load_stl10(DATA_DIR, "all")
  146. num_iter = 0
  147. # in this example, each dictionary has keys "image" and "label".
  148. for i, d in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  149. np.testing.assert_array_equal(d["image"], np.transpose(images[i], (1, 2, 0)))
  150. np.testing.assert_array_equal(d["label"], labels[i])
  151. num_iter += 1
  152. assert num_iter == 3
  153. def test_stl10_basic():
  154. """
  155. Feature: test_stl10_basic.
  156. Description: test basic usage of STL10Dataset.
  157. Expectation: get correct number of data.
  158. """
  159. logger.info("Test STL10Dataset Op")
  160. # case 1: test loading whole dataset.
  161. all_data = ds.STL10Dataset(DATA_DIR, "all")
  162. num_iter = 0
  163. for _ in all_data.create_dict_iterator(num_epochs=1):
  164. num_iter += 1
  165. assert num_iter == 3
  166. # case 2: test num_samples.
  167. all_data = ds.STL10Dataset(DATA_DIR, "all", num_samples=1)
  168. num_iter = 0
  169. for _ in all_data.create_dict_iterator(num_epochs=1):
  170. num_iter += 1
  171. assert num_iter == 1
  172. # case 3: test repeat.
  173. all_data = ds.STL10Dataset(DATA_DIR, "all", num_samples=2)
  174. all_data = all_data.repeat(5)
  175. num_iter = 0
  176. for _ in all_data.create_dict_iterator(num_epochs=1):
  177. num_iter += 1
  178. assert num_iter == 10
  179. # case 4: test batch with drop_remainder=False.
  180. all_data = ds.STL10Dataset(DATA_DIR, "all", num_samples=2)
  181. assert all_data.get_dataset_size() == 2
  182. assert all_data.get_batch_size() == 1
  183. all_data = all_data.batch(batch_size=2) # drop_remainder is default to be False.
  184. assert all_data.get_batch_size() == 2
  185. assert all_data.get_dataset_size() == 1
  186. num_iter = 0
  187. for _ in all_data.create_dict_iterator(num_epochs=1):
  188. num_iter += 1
  189. assert num_iter == 1
  190. # case 5: test batch with drop_remainder=True.
  191. all_data = ds.STL10Dataset(DATA_DIR, "all", num_samples=2)
  192. assert all_data.get_dataset_size() == 2
  193. assert all_data.get_batch_size() == 1
  194. all_data = all_data.batch(batch_size=2, drop_remainder=True) # the rest of incomplete batch will be dropped.
  195. assert all_data.get_dataset_size() == 1
  196. assert all_data.get_batch_size() == 2
  197. num_iter = 0
  198. for _ in all_data.create_dict_iterator(num_epochs=1):
  199. num_iter += 1
  200. assert num_iter == 1
  201. def test_stl10_sequential_sampler():
  202. """
  203. Feature: test_stl10_sequential_sampler.
  204. Description: test usage of STL10Dataset with SequentialSampler.
  205. Expectation: get correct number of data.
  206. """
  207. logger.info("Test STL10Dataset Op with SequentialSampler")
  208. num_samples = 2
  209. sampler = ds.SequentialSampler(num_samples=num_samples)
  210. all_data_1 = ds.STL10Dataset(DATA_DIR, "all", sampler=sampler)
  211. all_data_2 = ds.STL10Dataset(DATA_DIR, "all", shuffle=False, num_samples=num_samples)
  212. label_list_1, label_list_2 = [], []
  213. num_iter = 0
  214. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1),
  215. all_data_2.create_dict_iterator(num_epochs=1)):
  216. label_list_1.append(item1["label"].asnumpy())
  217. label_list_2.append(item2["label"].asnumpy())
  218. num_iter += 1
  219. np.testing.assert_array_equal(label_list_1, label_list_2)
  220. assert num_iter == num_samples
  221. def test_stl10_exception():
  222. """
  223. Feature: test_stl10_exception.
  224. Description: test error cases for STL10Dataset.
  225. Expectation: raise exception.
  226. """
  227. logger.info("Test error cases for STL10Dataset")
  228. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  229. with pytest.raises(RuntimeError, match=error_msg_1):
  230. ds.STL10Dataset(DATA_DIR, "all", shuffle=False, sampler=ds.PKSampler(3))
  231. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  232. with pytest.raises(RuntimeError, match=error_msg_2):
  233. ds.STL10Dataset(DATA_DIR, "all", sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
  234. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  235. with pytest.raises(RuntimeError, match=error_msg_3):
  236. ds.STL10Dataset(DATA_DIR, "all", num_shards=10)
  237. error_msg_4 = "shard_id is specified but num_shards is not"
  238. with pytest.raises(RuntimeError, match=error_msg_4):
  239. ds.STL10Dataset(DATA_DIR, "all", shard_id=0)
  240. error_msg_5 = "Input shard_id is not within the required interval"
  241. with pytest.raises(ValueError, match=error_msg_5):
  242. ds.STL10Dataset(DATA_DIR, "all", num_shards=5, shard_id=-1)
  243. with pytest.raises(ValueError, match=error_msg_5):
  244. ds.STL10Dataset(DATA_DIR, "all", num_shards=5, shard_id=5)
  245. with pytest.raises(ValueError, match=error_msg_5):
  246. ds.STL10Dataset(DATA_DIR, "all", num_shards=2, shard_id=5)
  247. error_msg_6 = "num_parallel_workers exceeds"
  248. with pytest.raises(ValueError, match=error_msg_6):
  249. ds.STL10Dataset(DATA_DIR, "all", shuffle=False, num_parallel_workers=0)
  250. with pytest.raises(ValueError, match=error_msg_6):
  251. ds.STL10Dataset(DATA_DIR, "all", shuffle=False, num_parallel_workers=256)
  252. with pytest.raises(ValueError, match=error_msg_6):
  253. ds.STL10Dataset(DATA_DIR, "all", shuffle=False, num_parallel_workers=-2)
  254. error_msg_7 = "Argument shard_id"
  255. with pytest.raises(TypeError, match=error_msg_7):
  256. ds.STL10Dataset(DATA_DIR, "all", num_shards=2, shard_id="0")
  257. def exception_func(item):
  258. raise Exception("Error occur!")
  259. error_msg_8 = "The corresponding data files"
  260. with pytest.raises(RuntimeError, match=error_msg_8):
  261. all_data = ds.STL10Dataset(DATA_DIR, "all")
  262. all_data = all_data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  263. for _ in all_data.__iter__():
  264. pass
  265. with pytest.raises(RuntimeError, match=error_msg_8):
  266. all_data = ds.STL10Dataset(DATA_DIR, "all")
  267. all_data = all_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  268. for _ in all_data.__iter__():
  269. pass
  270. error_msg_9 = "does not exist or permission denied!"
  271. with pytest.raises(ValueError, match=error_msg_9):
  272. all_data = ds.STL10Dataset(WRONG_DIR, "all")
  273. for _ in all_data.__iter__():
  274. pass
  275. def test_stl10_visualize(plot=False):
  276. """
  277. Feature: test_stl10_visualize.
  278. Description: visualize STL10Dataset results.
  279. Expectation: get correct number of data and plot them.
  280. """
  281. logger.info("Test STL10Dataset visualization")
  282. all_data = ds.STL10Dataset(DATA_DIR, "all", num_samples=2, shuffle=False)
  283. num_iter = 0
  284. image_list, label_list = [], []
  285. for item in all_data.create_dict_iterator(num_epochs=1, output_numpy=True):
  286. image = item["image"]
  287. label = item["label"]
  288. image_list.append(image)
  289. label_list.append("label {}".format(label))
  290. assert isinstance(image, np.ndarray)
  291. assert image.shape == (96, 96, 3)
  292. assert image.dtype == np.uint8
  293. assert label.dtype == np.int32
  294. num_iter += 1
  295. assert num_iter == 2
  296. if plot:
  297. visualize_dataset(image_list, label_list)
  298. def test_stl10_usage():
  299. """
  300. Feature: test_stl10_usage.
  301. Description: validate STL10Dataset image readings.
  302. Expectation: get correct number of data.
  303. """
  304. logger.info("Test STL10Dataset usage flag")
  305. def test_config(usage, path=None):
  306. path = DATA_DIR if path is None else path
  307. try:
  308. data = ds.STL10Dataset(path, usage=usage, shuffle=False)
  309. num_rows = 0
  310. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  311. num_rows += 1
  312. except (ValueError, TypeError, RuntimeError) as e:
  313. return str(e)
  314. return num_rows
  315. assert test_config("train") == 1
  316. assert test_config("test") == 1
  317. assert test_config("unlabeled") == 1
  318. assert test_config("train+unlabeled") == 2
  319. assert test_config("all") == 3
  320. assert "Input usage is not within the valid set of ['train', 'test', 'unlabeled', 'train+unlabeled', 'all']."\
  321. in test_config("invalid")
  322. assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
  323. # change this directory to the folder that contains all STL10 files.
  324. all_files_path = None
  325. # the following tests on the entire datasets.
  326. if all_files_path is not None:
  327. assert test_config("train", all_files_path) == 1
  328. assert ds.STL10Dataset(all_files_path, usage="train").get_dataset_size() == 1
if __name__ == '__main__':
    # Run every STL10 dataset test case when executed directly as a script
    # (plot=True shows the sample images; under pytest, plotting stays off).
    test_stl10_content_check()
    test_stl10_basic()
    test_stl10_sequential_sampler()
    test_stl10_exception()
    test_stl10_visualize(plot=True)
    test_stl10_usage()