You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_qmnist.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test QMnistDataset operator
  17. """
  18. import os
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. import pytest
  22. import mindspore.dataset as ds
  23. import mindspore.dataset.vision.c_transforms as vision
  24. from mindspore import log as logger
  25. DATA_DIR = "../data/dataset/testQMnistData"
  26. def load_qmnist(path, usage, compat=True):
  27. """
  28. load QMNIST data
  29. """
  30. image_path = []
  31. label_path = []
  32. image_ext = "images-idx3-ubyte"
  33. label_ext = "labels-idx2-int"
  34. train_prefix = "qmnist-train"
  35. test_prefix = "qmnist-test"
  36. nist_prefix = "xnist"
  37. assert usage in ["train", "test", "nist", "all"]
  38. if usage == "train":
  39. image_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + image_ext)))
  40. label_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + label_ext)))
  41. elif usage == "test":
  42. image_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + image_ext)))
  43. label_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + label_ext)))
  44. elif usage == "nist":
  45. image_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + image_ext)))
  46. label_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + label_ext)))
  47. elif usage == "all":
  48. image_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + image_ext)))
  49. label_path.append(os.path.realpath(os.path.join(path, train_prefix + "-" + label_ext)))
  50. image_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + image_ext)))
  51. label_path.append(os.path.realpath(os.path.join(path, test_prefix + "-" + label_ext)))
  52. image_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + image_ext)))
  53. label_path.append(os.path.realpath(os.path.join(path, nist_prefix + "-" + label_ext)))
  54. assert len(image_path) == len(label_path)
  55. images = []
  56. labels = []
  57. for i, _ in enumerate(image_path):
  58. with open(image_path[i], 'rb') as image_file:
  59. image_file.read(16)
  60. image = np.fromfile(image_file, dtype=np.uint8)
  61. image = image.reshape(-1, 28, 28, 1)
  62. image[image > 0] = 255 # Perform binarization to maintain consistency with our API
  63. images.append(image)
  64. with open(label_path[i], 'rb') as label_file:
  65. label_file.read(12)
  66. label = np.fromfile(label_file, dtype='>u4')
  67. label = label.reshape(-1, 8)
  68. labels.append(label)
  69. images = np.concatenate(images, 0)
  70. labels = np.concatenate(labels, 0)
  71. if compat:
  72. return images, labels[:, 0]
  73. return images, labels
  74. def visualize_dataset(images, labels):
  75. """
  76. Helper function to visualize the dataset samples
  77. """
  78. num_samples = len(images)
  79. for i in range(num_samples):
  80. plt.subplot(1, num_samples, i + 1)
  81. plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
  82. plt.title(labels[i])
  83. plt.show()
  84. def test_qmnist_content_check():
  85. """
  86. Validate QMnistDataset image readings
  87. """
  88. logger.info("Test QMnistDataset Op with content check")
  89. for usage in ["train", "test", "nist", "all"]:
  90. data1 = ds.QMnistDataset(DATA_DIR, usage, True, num_samples=10, shuffle=False)
  91. images, labels = load_qmnist(DATA_DIR, usage, True)
  92. num_iter = 0
  93. # in this example, each dictionary has keys "image" and "label"
  94. image_list, label_list = [], []
  95. for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  96. image_list.append(data["image"])
  97. label_list.append("label {}".format(data["label"]))
  98. np.testing.assert_array_equal(data["image"], images[i])
  99. np.testing.assert_array_equal(data["label"], labels[i])
  100. num_iter += 1
  101. assert num_iter == 10
  102. for usage in ["train", "test", "nist", "all"]:
  103. data1 = ds.QMnistDataset(DATA_DIR, usage, False, num_samples=10, shuffle=False)
  104. images, labels = load_qmnist(DATA_DIR, usage, False)
  105. num_iter = 0
  106. # in this example, each dictionary has keys "image" and "label"
  107. image_list, label_list = [], []
  108. for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  109. image_list.append(data["image"])
  110. label_list.append("label {}".format(data["label"]))
  111. np.testing.assert_array_equal(data["image"], images[i])
  112. np.testing.assert_array_equal(data["label"], labels[i])
  113. num_iter += 1
  114. assert num_iter == 10
  115. def test_qmnist_basic():
  116. """
  117. Validate QMnistDataset
  118. """
  119. logger.info("Test QMnistDataset Op")
  120. # case 1: test loading whole dataset
  121. data1 = ds.QMnistDataset(DATA_DIR, "train", True)
  122. num_iter1 = 0
  123. for _ in data1.create_dict_iterator(num_epochs=1):
  124. num_iter1 += 1
  125. assert num_iter1 == 10
  126. # case 2: test num_samples
  127. data2 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=5)
  128. num_iter2 = 0
  129. for _ in data2.create_dict_iterator(num_epochs=1):
  130. num_iter2 += 1
  131. assert num_iter2 == 5
  132. # case 3: test repeat
  133. data3 = ds.QMnistDataset(DATA_DIR, "train", True)
  134. data3 = data3.repeat(5)
  135. num_iter3 = 0
  136. for _ in data3.create_dict_iterator(num_epochs=1):
  137. num_iter3 += 1
  138. assert num_iter3 == 50
  139. # case 4: test batch with drop_remainder=False
  140. data4 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
  141. assert data4.get_dataset_size() == 10
  142. assert data4.get_batch_size() == 1
  143. data4 = data4.batch(batch_size=7) # drop_remainder is default to be False
  144. assert data4.get_dataset_size() == 2
  145. assert data4.get_batch_size() == 7
  146. num_iter4 = 0
  147. for _ in data4.create_dict_iterator(num_epochs=1):
  148. num_iter4 += 1
  149. assert num_iter4 == 2
  150. # case 5: test batch with drop_remainder=True
  151. data5 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
  152. assert data5.get_dataset_size() == 10
  153. assert data5.get_batch_size() == 1
  154. data5 = data5.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped
  155. assert data5.get_dataset_size() == 3
  156. assert data5.get_batch_size() == 3
  157. num_iter5 = 0
  158. for _ in data5.create_dict_iterator(num_epochs=1):
  159. num_iter5 += 1
  160. assert num_iter5 == 3
  161. # case 6: test get_col_names
  162. dataset = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10)
  163. assert dataset.get_col_names() == ["image", "label"]
  164. def test_qmnist_pk_sampler():
  165. """
  166. Test QMnistDataset with PKSampler
  167. """
  168. logger.info("Test QMnistDataset Op with PKSampler")
  169. golden = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  170. sampler = ds.PKSampler(10)
  171. data = ds.QMnistDataset(DATA_DIR, "nist", True, sampler=sampler)
  172. num_iter = 0
  173. label_list = []
  174. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  175. label_list.append(item["label"])
  176. num_iter += 1
  177. np.testing.assert_array_equal(golden, label_list)
  178. assert num_iter == 10
  179. def test_qmnist_sequential_sampler():
  180. """
  181. Test QMnistDataset with SequentialSampler
  182. """
  183. logger.info("Test QMnistDataset Op with SequentialSampler")
  184. num_samples = 10
  185. sampler = ds.SequentialSampler(num_samples=num_samples)
  186. data1 = ds.QMnistDataset(DATA_DIR, "train", True, sampler=sampler)
  187. data2 = ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_samples=num_samples)
  188. label_list1, label_list2 = [], []
  189. num_iter = 0
  190. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
  191. label_list1.append(item1["label"].asnumpy())
  192. label_list2.append(item2["label"].asnumpy())
  193. num_iter += 1
  194. np.testing.assert_array_equal(label_list1, label_list2)
  195. assert num_iter == num_samples
  196. def test_qmnist_exception():
  197. """
  198. Test error cases for QMnistDataset
  199. """
  200. logger.info("Test error cases for MnistDataset")
  201. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  202. with pytest.raises(RuntimeError, match=error_msg_1):
  203. ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, sampler=ds.PKSampler(3))
  204. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  205. with pytest.raises(RuntimeError, match=error_msg_2):
  206. ds.QMnistDataset(DATA_DIR, "nist", True, sampler=ds.PKSampler(1), num_shards=2, shard_id=0)
  207. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  208. with pytest.raises(RuntimeError, match=error_msg_3):
  209. ds.QMnistDataset(DATA_DIR, "train", True, num_shards=10)
  210. error_msg_4 = "shard_id is specified but num_shards is not"
  211. with pytest.raises(RuntimeError, match=error_msg_4):
  212. ds.QMnistDataset(DATA_DIR, "train", True, shard_id=0)
  213. error_msg_5 = "Input shard_id is not within the required interval"
  214. with pytest.raises(ValueError, match=error_msg_5):
  215. ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=-1)
  216. with pytest.raises(ValueError, match=error_msg_5):
  217. ds.QMnistDataset(DATA_DIR, "train", True, num_shards=5, shard_id=5)
  218. with pytest.raises(ValueError, match=error_msg_5):
  219. ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id=5)
  220. error_msg_6 = "num_parallel_workers exceeds"
  221. with pytest.raises(ValueError, match=error_msg_6):
  222. ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=0)
  223. with pytest.raises(ValueError, match=error_msg_6):
  224. ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=256)
  225. with pytest.raises(ValueError, match=error_msg_6):
  226. ds.QMnistDataset(DATA_DIR, "train", True, shuffle=False, num_parallel_workers=-2)
  227. error_msg_7 = "Argument shard_id"
  228. with pytest.raises(TypeError, match=error_msg_7):
  229. ds.QMnistDataset(DATA_DIR, "train", True, num_shards=2, shard_id="0")
  230. def exception_func(item):
  231. raise Exception("Error occur!")
  232. error_msg_8 = "The corresponding data files"
  233. with pytest.raises(RuntimeError, match=error_msg_8):
  234. data = ds.QMnistDataset(DATA_DIR, "train", True)
  235. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  236. for _ in data.__iter__():
  237. pass
  238. with pytest.raises(RuntimeError, match=error_msg_8):
  239. data = ds.QMnistDataset(DATA_DIR, "train", True)
  240. data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  241. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  242. for _ in data.__iter__():
  243. pass
  244. with pytest.raises(RuntimeError, match=error_msg_8):
  245. data = ds.QMnistDataset(DATA_DIR, "train", True)
  246. data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
  247. for _ in data.__iter__():
  248. pass
  249. def test_qmnist_visualize(plot=False):
  250. """
  251. Visualize QMnistDataset results
  252. """
  253. logger.info("Test QMnistDataset visualization")
  254. data1 = ds.QMnistDataset(DATA_DIR, "train", True, num_samples=10, shuffle=False)
  255. num_iter = 0
  256. image_list, label_list = [], []
  257. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  258. image = item["image"]
  259. label = item["label"]
  260. image_list.append(image)
  261. label_list.append("label {}".format(label))
  262. assert isinstance(image, np.ndarray)
  263. assert image.shape == (28, 28, 1)
  264. assert image.dtype == np.uint8
  265. assert label.dtype == np.uint32
  266. num_iter += 1
  267. assert num_iter == 10
  268. if plot:
  269. visualize_dataset(image_list, label_list)
  270. def test_qmnist_usage():
  271. """
  272. Validate QMnistDataset image readings
  273. """
  274. logger.info("Test QMnistDataset usage flag")
  275. def test_config(usage, path=None):
  276. path = DATA_DIR if path is None else path
  277. try:
  278. data = ds.QMnistDataset(path, usage=usage, compat=True, shuffle=False)
  279. num_rows = 0
  280. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  281. num_rows += 1
  282. except (ValueError, TypeError, RuntimeError) as e:
  283. return str(e)
  284. return num_rows
  285. assert test_config("train") == 10
  286. assert test_config("test") == 10
  287. assert test_config("nist") == 10
  288. assert test_config("all") == 30
  289. assert "usage is not within the valid set of ['train', 'test', 'test10k', 'test50k', 'nist', 'all']" in\
  290. test_config("invalid")
  291. assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
if __name__ == '__main__':
    # Run every QMnistDataset test case when executed as a stand-alone script.
    test_qmnist_content_check()
    test_qmnist_basic()
    test_qmnist_pk_sampler()
    test_qmnist_sequential_sampler()
    test_qmnist_exception()
    test_qmnist_visualize(plot=True)
    test_qmnist_usage()