You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_datasets_svhn.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test SVHN dataset operators
  17. """
  18. import os
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. import pytest
  22. from scipy.io import loadmat
  23. import mindspore.dataset as ds
  24. from mindspore import log as logger
  25. DATA_DIR = "../data/dataset/testSVHNData"
  26. WRONG_DIR = "../data/dataset/testMnistData"
  27. def load_mat(mode, path):
  28. """
  29. Feature: load_mat.
  30. Description: load .mat file.
  31. Expectation: get .mat of svhn dataset.
  32. """
  33. filename = mode + "_32x32.mat"
  34. mat_data = loadmat(os.path.realpath(os.path.join(path, filename)))
  35. data = np.transpose(mat_data['X'], [3, 0, 1, 2])
  36. label = mat_data['y'].astype(np.uint32).squeeze()
  37. np.place(label, label == 10, 0)
  38. return data, label
  39. def load_svhn(path, usage):
  40. """
  41. Feature: load_svhn.
  42. Description: load svhn.
  43. Expectation: get data of svhn dataset.
  44. """
  45. assert usage in ["train", "test", "extra", "all"]
  46. usage_all = ["train", "test", "extra"]
  47. data = np.array([], dtype=np.uint8)
  48. label = np.array([], dtype=np.uint32)
  49. if usage == "all":
  50. for _usage in usage_all:
  51. current_data, current_label = load_mat(_usage, path)
  52. data = np.concatenate((data, current_data)) if data.size else current_data
  53. label = np.concatenate((label, current_label)) if label.size else current_label
  54. else:
  55. data, label = load_mat(usage, path)
  56. return data, label
  57. def visualize_dataset(images, labels):
  58. """
  59. Feature: visualize_dataset.
  60. Description: visualize svhn dataset.
  61. Expectation: plot images.
  62. """
  63. num_samples = len(images)
  64. for i in range(num_samples):
  65. plt.subplot(1, num_samples, i + 1)
  66. plt.imshow(images[i])
  67. plt.title(labels[i])
  68. plt.show()
  69. def test_svhn_content_check():
  70. """
  71. Feature: test_svhn_content_check.
  72. Description: validate SVHNDataset image readings.
  73. Expectation: get correct number of data and correct content.
  74. """
  75. logger.info("Test SVHNDataset Op with content check")
  76. train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2, shuffle=False)
  77. images, labels = load_svhn(DATA_DIR, "train")
  78. num_iter = 0
  79. # in this example, each dictionary has keys "image" and "label".
  80. for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  81. np.testing.assert_array_equal(data["image"], images[i])
  82. np.testing.assert_array_equal(data["label"], labels[i])
  83. num_iter += 1
  84. assert num_iter == 2
  85. test_data = ds.SVHNDataset(DATA_DIR, "test", num_samples=4, shuffle=False)
  86. images, labels = load_svhn(DATA_DIR, "test")
  87. num_iter = 0
  88. # in this example, each dictionary has keys "image" and "label".
  89. for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  90. np.testing.assert_array_equal(data["image"], images[i])
  91. np.testing.assert_array_equal(data["label"], labels[i])
  92. num_iter += 1
  93. assert num_iter == 4
  94. extra_data = ds.SVHNDataset(DATA_DIR, "extra", num_samples=6, shuffle=False)
  95. images, labels = load_svhn(DATA_DIR, "extra")
  96. num_iter = 0
  97. # in this example, each dictionary has keys "image" and "label".
  98. for i, data in enumerate(extra_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  99. np.testing.assert_array_equal(data["image"], images[i])
  100. np.testing.assert_array_equal(data["label"], labels[i])
  101. num_iter += 1
  102. assert num_iter == 6
  103. all_data = ds.SVHNDataset(DATA_DIR, "all", num_samples=12, shuffle=False)
  104. images, labels = load_svhn(DATA_DIR, "all")
  105. num_iter = 0
  106. # in this example, each dictionary has keys "image" and "label".
  107. for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  108. np.testing.assert_array_equal(data["image"], images[i])
  109. np.testing.assert_array_equal(data["label"], labels[i])
  110. num_iter += 1
  111. assert num_iter == 12
  112. def test_svhn_basic():
  113. """
  114. Feature: test_svhn_basic.
  115. Description: test basic usage of SVHNDataset.
  116. Expectation: get correct number of data.
  117. """
  118. logger.info("Test SVHNDataset Op")
  119. # case 1: test loading whole dataset.
  120. default_data = ds.SVHNDataset(DATA_DIR)
  121. num_iter = 0
  122. for _ in default_data.create_dict_iterator(num_epochs=1):
  123. num_iter += 1
  124. assert num_iter == 12
  125. all_data = ds.SVHNDataset(DATA_DIR, "all")
  126. num_iter = 0
  127. for _ in all_data.create_dict_iterator(num_epochs=1):
  128. num_iter += 1
  129. assert num_iter == 12
  130. # case 2: test num_samples.
  131. train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2)
  132. num_iter = 0
  133. for _ in train_data.create_dict_iterator(num_epochs=1):
  134. num_iter += 1
  135. assert num_iter == 2
  136. # case 3: test repeat.
  137. train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2)
  138. train_data = train_data.repeat(5)
  139. num_iter = 0
  140. for _ in train_data.create_dict_iterator(num_epochs=1):
  141. num_iter += 1
  142. assert num_iter == 10
  143. # case 4: test batch with drop_remainder=False.
  144. train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2)
  145. assert train_data.get_dataset_size() == 2
  146. assert train_data.get_batch_size() == 1
  147. train_data = train_data.batch(batch_size=2) # drop_remainder is default to be False.
  148. assert train_data.get_batch_size() == 2
  149. assert train_data.get_dataset_size() == 1
  150. num_iter = 0
  151. for _ in train_data.create_dict_iterator(num_epochs=1):
  152. num_iter += 1
  153. assert num_iter == 1
  154. # case 5: test batch with drop_remainder=True.
  155. train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2)
  156. assert train_data.get_dataset_size() == 2
  157. assert train_data.get_batch_size() == 1
  158. train_data = train_data.batch(batch_size=2, drop_remainder=True) # the rest of incomplete batch will be dropped.
  159. assert train_data.get_dataset_size() == 1
  160. assert train_data.get_batch_size() == 2
  161. num_iter = 0
  162. for _ in train_data.create_dict_iterator(num_epochs=1):
  163. num_iter += 1
  164. assert num_iter == 1
  165. # case 6: test num_parallel_workers>1
  166. shared_mem_flag = ds.config.get_enable_shared_mem()
  167. ds.config.set_enable_shared_mem(False)
  168. all_data = ds.SVHNDataset(DATA_DIR, "all", num_parallel_workers=2)
  169. num_iter = 0
  170. for _ in all_data.create_dict_iterator(num_epochs=1):
  171. num_iter += 1
  172. assert num_iter == 12
  173. ds.config.set_enable_shared_mem(shared_mem_flag)
  174. # case 7: test map method
  175. input_columns = ["image"]
  176. image1, image2 = [], []
  177. dataset = ds.SVHNDataset(DATA_DIR, "all")
  178. for data in dataset.create_dict_iterator(output_numpy=True):
  179. image1.extend(data['image'])
  180. operations = [(lambda x: x + x)]
  181. dataset = dataset.map(input_columns=input_columns, operations=operations)
  182. for data in dataset.create_dict_iterator(output_numpy=True):
  183. image2.extend(data['image'])
  184. assert len(image1) == len(image2)
  185. # case 8: test batch
  186. dataset = ds.SVHNDataset(DATA_DIR, "all")
  187. dataset = dataset.batch(batch_size=3)
  188. num_iter = 0
  189. for data in dataset.create_dict_iterator(output_numpy=True):
  190. num_iter += 1
  191. assert num_iter == 4
  192. def test_svhn_sequential_sampler():
  193. """
  194. Feature: test_svhn_sequential_sampler.
  195. Description: test usage of SVHNDataset with SequentialSampler.
  196. Expectation: get correct number of data.
  197. """
  198. logger.info("Test SVHNDataset Op with SequentialSampler")
  199. num_samples = 2
  200. sampler = ds.SequentialSampler(num_samples=num_samples)
  201. train_data_1 = ds.SVHNDataset(DATA_DIR, "train", sampler=sampler)
  202. train_data_2 = ds.SVHNDataset(DATA_DIR, "train", shuffle=False, num_samples=num_samples)
  203. label_list_1, label_list_2 = [], []
  204. num_iter = 0
  205. for item1, item2 in zip(train_data_1.create_dict_iterator(num_epochs=1),
  206. train_data_2.create_dict_iterator(num_epochs=1)):
  207. label_list_1.append(item1["label"].asnumpy())
  208. label_list_2.append(item2["label"].asnumpy())
  209. num_iter += 1
  210. np.testing.assert_array_equal(label_list_1, label_list_2)
  211. assert num_iter == num_samples
  212. def test_svhn_exception():
  213. """
  214. Feature: test_svhn_exception.
  215. Description: test error cases for SVHNDataset.
  216. Expectation: raise exception.
  217. """
  218. logger.info("Test error cases for SVHNDataset")
  219. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  220. with pytest.raises(RuntimeError, match=error_msg_1):
  221. ds.SVHNDataset(DATA_DIR, "train", shuffle=False, sampler=ds.SequentialSampler(1))
  222. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  223. with pytest.raises(RuntimeError, match=error_msg_2):
  224. ds.SVHNDataset(DATA_DIR, "train", sampler=ds.SequentialSampler(1), num_shards=2, shard_id=0)
  225. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  226. with pytest.raises(RuntimeError, match=error_msg_3):
  227. ds.SVHNDataset(DATA_DIR, "train", num_shards=10)
  228. error_msg_4 = "shard_id is specified but num_shards is not"
  229. with pytest.raises(RuntimeError, match=error_msg_4):
  230. ds.SVHNDataset(DATA_DIR, "train", shard_id=0)
  231. error_msg_5 = "Input shard_id is not within the required interval"
  232. with pytest.raises(ValueError, match=error_msg_5):
  233. ds.SVHNDataset(DATA_DIR, "train", num_shards=5, shard_id=-1)
  234. with pytest.raises(ValueError, match=error_msg_5):
  235. ds.SVHNDataset(DATA_DIR, "train", num_shards=5, shard_id=5)
  236. with pytest.raises(ValueError, match=error_msg_5):
  237. ds.SVHNDataset(DATA_DIR, "train", num_shards=2, shard_id=5)
  238. error_msg_6 = "num_parallel_workers exceeds"
  239. with pytest.raises(ValueError, match=error_msg_6):
  240. ds.SVHNDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=0)
  241. with pytest.raises(ValueError, match=error_msg_6):
  242. ds.SVHNDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=256)
  243. with pytest.raises(ValueError, match=error_msg_6):
  244. ds.SVHNDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=-2)
  245. error_msg_7 = "Argument shard_id"
  246. with pytest.raises(TypeError, match=error_msg_7):
  247. ds.SVHNDataset(DATA_DIR, "train", num_shards=2, shard_id="0")
  248. error_msg_8 = "does not exist or permission denied!"
  249. with pytest.raises(ValueError, match=error_msg_8):
  250. train_data = ds.SVHNDataset(WRONG_DIR, "train")
  251. for _ in train_data.__iter__():
  252. pass
  253. def test_svhn_visualize(plot=False):
  254. """
  255. Feature: test_svhn_visualize.
  256. Description: visualize SVHNDataset results.
  257. Expectation: get correct number of data and plot them.
  258. """
  259. logger.info("Test SVHNDataset visualization")
  260. train_data = ds.SVHNDataset(DATA_DIR, "train", num_samples=2, shuffle=False)
  261. num_iter = 0
  262. image_list, label_list = [], []
  263. for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
  264. image = item["image"]
  265. label = item["label"]
  266. image_list.append(image)
  267. label_list.append("label {}".format(label))
  268. assert isinstance(image, np.ndarray)
  269. assert image.shape == (32, 32, 3)
  270. assert image.dtype == np.uint8
  271. assert label.dtype == np.uint32
  272. num_iter += 1
  273. assert num_iter == 2
  274. if plot:
  275. visualize_dataset(image_list, label_list)
  276. def test_svhn_usage():
  277. """
  278. Feature: test_svhn_usage.
  279. Description: validate SVHNDataset image readings.
  280. Expectation: get correct number of data.
  281. """
  282. logger.info("Test SVHNDataset usage flag")
  283. def test_config(usage, path=None):
  284. path = DATA_DIR if path is None else path
  285. try:
  286. data = ds.SVHNDataset(path, usage=usage, shuffle=False)
  287. num_rows = 0
  288. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  289. num_rows += 1
  290. except (ValueError, TypeError, RuntimeError) as e:
  291. return str(e)
  292. return num_rows
  293. assert test_config("train") == 2
  294. assert test_config("test") == 4
  295. assert test_config("extra") == 6
  296. assert test_config("all") == 12
  297. assert "usage is not within the valid set of ['train', 'test', 'extra', 'all']" in test_config("invalid")
  298. assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
  299. data_path = None
  300. # the following tests on the entire datasets.
  301. if data_path is not None:
  302. assert test_config("train", data_path) == 2
  303. assert test_config("test", data_path) == 4
  304. assert test_config("extra", data_path) == 6
  305. assert test_config("all", data_path) == 12
  306. assert ds.SVHNDataset(data_path, usage="train").get_dataset_size() == 2
  307. assert ds.SVHNDataset(data_path, usage="test").get_dataset_size() == 4
  308. assert ds.SVHNDataset(data_path, usage="extra").get_dataset_size() == 6
  309. assert ds.SVHNDataset(data_path, usage="all").get_dataset_size() == 12
  310. if __name__ == '__main__':
  311. test_svhn_content_check()
  312. test_svhn_basic()
  313. test_svhn_sequential_sampler()
  314. test_svhn_exception()
  315. test_svhn_visualize(plot=True)
  316. test_svhn_usage()