You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_usps.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test USPS dataset operators
  17. """
  18. import os
  19. from typing import cast
  20. import matplotlib.pyplot as plt
  21. import numpy as np
  22. import pytest
  23. import mindspore.dataset as ds
  24. import mindspore.dataset.vision.c_transforms as vision
  25. from mindspore import log as logger
# Directory holding the small USPS fixture used by these tests; load_usps reads
# the "usps" (train) and "usps.t" (test) files from it. The path is relative to
# the current working directory, so the tests presumably must be launched from
# the directory containing this file.
DATA_DIR = "../data/dataset/testUSPSDataset"
# A dataset directory for a different dataset (MNIST fixture) that presumably
# contains no USPS files; test_usps_exception points USPSDataset at it to
# provoke the "failed to find USPS ... data file" errors.
WRONG_DIR = "../data/dataset/testMnistData"
  28. def load_usps(path, usage):
  29. """
  30. load USPS data
  31. """
  32. assert usage in ["train", "test"]
  33. if usage == "train":
  34. data_path = os.path.realpath(os.path.join(path, "usps"))
  35. elif usage == "test":
  36. data_path = os.path.realpath(os.path.join(path, "usps.t"))
  37. with open(data_path, 'r') as f:
  38. raw_data = [line.split() for line in f.readlines()]
  39. tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data]
  40. images = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16, 1))
  41. images = ((cast(np.ndarray, images) + 1) / 2 * 255).astype(dtype=np.uint8)
  42. labels = [int(d[0]) - 1 for d in raw_data]
  43. return images, labels
  44. def visualize_dataset(images, labels):
  45. """
  46. Helper function to visualize the dataset samples
  47. """
  48. num_samples = len(images)
  49. for i in range(num_samples):
  50. plt.subplot(1, num_samples, i + 1)
  51. plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
  52. plt.title(labels[i])
  53. plt.show()
  54. def test_usps_content_check():
  55. """
  56. Validate USPSDataset image readings
  57. """
  58. logger.info("Test USPSDataset Op with content check")
  59. train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=10, shuffle=False)
  60. images, labels = load_usps(DATA_DIR, "train")
  61. num_iter = 0
  62. # in this example, each dictionary has keys "image" and "label"
  63. for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  64. for m in range(16):
  65. for n in range(16):
  66. assert (data["image"][m, n, 0] != 0 or images[i][m, n, 0] != 255) and \
  67. (data["image"][m, n, 0] != 255 or images[i][m, n, 0] != 0)
  68. assert (data["image"][m, n, 0] == images[i][m, n, 0]) or\
  69. (data["image"][m, n, 0] == images[i][m, n, 0] + 1) or\
  70. (data["image"][m, n, 0] + 1 == images[i][m, n, 0])
  71. np.testing.assert_array_equal(data["label"], labels[i])
  72. num_iter += 1
  73. assert num_iter == 3
  74. test_data = ds.USPSDataset(DATA_DIR, "test", num_samples=3, shuffle=False)
  75. images, labels = load_usps(DATA_DIR, "test")
  76. num_iter = 0
  77. # in this example, each dictionary has keys "image" and "label"
  78. for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  79. for m in range(16):
  80. for n in range(16):
  81. if (data["image"][m, n, 0] == 0 and images[i][m, n, 0] == 255) or\
  82. (data["image"][m, n, 0] == 255 and images[i][m, n, 0] == 0):
  83. assert False
  84. if (data["image"][m, n, 0] != images[i][m, n, 0]) and\
  85. (data["image"][m, n, 0] != images[i][m, n, 0] + 1) and\
  86. (data["image"][m, n, 0] + 1 != images[i][m, n, 0]):
  87. assert False
  88. np.testing.assert_array_equal(data["label"], labels[i])
  89. num_iter += 1
  90. assert num_iter == 3
  91. def test_usps_basic():
  92. """
  93. Validate USPSDataset
  94. """
  95. logger.info("Test USPSDataset Op")
  96. # case 1: test loading whole dataset
  97. train_data = ds.USPSDataset(DATA_DIR, "train")
  98. num_iter = 0
  99. for _ in train_data.create_dict_iterator(num_epochs=1):
  100. num_iter += 1
  101. assert num_iter == 3
  102. test_data = ds.USPSDataset(DATA_DIR, "test")
  103. num_iter = 0
  104. for _ in test_data.create_dict_iterator(num_epochs=1):
  105. num_iter += 1
  106. assert num_iter == 3
  107. # case 2: test num_samples
  108. train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=2)
  109. num_iter = 0
  110. for _ in train_data.create_dict_iterator(num_epochs=1):
  111. num_iter += 1
  112. assert num_iter == 2
  113. # case 3: test repeat
  114. train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=2)
  115. train_data = train_data.repeat(5)
  116. num_iter = 0
  117. for _ in train_data.create_dict_iterator(num_epochs=1):
  118. num_iter += 1
  119. assert num_iter == 10
  120. # case 4: test batch with drop_remainder=False
  121. train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3)
  122. assert train_data.get_dataset_size() == 3
  123. assert train_data.get_batch_size() == 1
  124. train_data = train_data.batch(batch_size=2) # drop_remainder is default to be False
  125. assert train_data.get_batch_size() == 2
  126. assert train_data.get_dataset_size() == 2
  127. num_iter = 0
  128. for _ in train_data.create_dict_iterator(num_epochs=1):
  129. num_iter += 1
  130. assert num_iter == 2
  131. # case 5: test batch with drop_remainder=True
  132. train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3)
  133. assert train_data.get_dataset_size() == 3
  134. assert train_data.get_batch_size() == 1
  135. train_data = train_data.batch(batch_size=2, drop_remainder=True) # the rest of incomplete batch will be dropped
  136. assert train_data.get_dataset_size() == 1
  137. assert train_data.get_batch_size() == 2
  138. num_iter = 0
  139. for _ in train_data.create_dict_iterator(num_epochs=1):
  140. num_iter += 1
  141. assert num_iter == 1
  142. def test_usps_exception():
  143. """
  144. Test error cases for USPSDataset
  145. """
  146. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  147. with pytest.raises(RuntimeError, match=error_msg_3):
  148. ds.USPSDataset(DATA_DIR, "train", num_shards=10)
  149. ds.USPSDataset(DATA_DIR, "test", num_shards=10)
  150. error_msg_4 = "shard_id is specified but num_shards is not"
  151. with pytest.raises(RuntimeError, match=error_msg_4):
  152. ds.USPSDataset(DATA_DIR, "train", shard_id=0)
  153. ds.USPSDataset(DATA_DIR, "test", shard_id=0)
  154. error_msg_5 = "Input shard_id is not within the required interval"
  155. with pytest.raises(ValueError, match=error_msg_5):
  156. ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=-1)
  157. ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=-1)
  158. with pytest.raises(ValueError, match=error_msg_5):
  159. ds.USPSDataset(DATA_DIR, "train", num_shards=5, shard_id=5)
  160. ds.USPSDataset(DATA_DIR, "test", num_shards=5, shard_id=5)
  161. with pytest.raises(ValueError, match=error_msg_5):
  162. ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id=5)
  163. ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id=5)
  164. error_msg_6 = "num_parallel_workers exceeds"
  165. with pytest.raises(ValueError, match=error_msg_6):
  166. ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=0)
  167. ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=0)
  168. with pytest.raises(ValueError, match=error_msg_6):
  169. ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=256)
  170. ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=256)
  171. with pytest.raises(ValueError, match=error_msg_6):
  172. ds.USPSDataset(DATA_DIR, "train", shuffle=False, num_parallel_workers=-2)
  173. ds.USPSDataset(DATA_DIR, "test", shuffle=False, num_parallel_workers=-2)
  174. error_msg_7 = "Argument shard_id"
  175. with pytest.raises(TypeError, match=error_msg_7):
  176. ds.USPSDataset(DATA_DIR, "train", num_shards=2, shard_id="0")
  177. ds.USPSDataset(DATA_DIR, "test", num_shards=2, shard_id="0")
  178. error_msg_8 = "invalid input shape"
  179. with pytest.raises(RuntimeError, match=error_msg_8):
  180. train_data = ds.USPSDataset(DATA_DIR, "train")
  181. train_data = train_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  182. for _ in train_data.__iter__():
  183. pass
  184. test_data = ds.USPSDataset(DATA_DIR, "test")
  185. test_data = test_data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  186. for _ in test_data.__iter__():
  187. pass
  188. error_msg_9 = "failed to find USPS train data file"
  189. with pytest.raises(RuntimeError, match=error_msg_9):
  190. train_data = ds.USPSDataset(WRONG_DIR, "train")
  191. for _ in train_data.__iter__():
  192. pass
  193. error_msg_10 = "failed to find USPS test data file"
  194. with pytest.raises(RuntimeError, match=error_msg_10):
  195. test_data = ds.USPSDataset(WRONG_DIR, "test")
  196. for _ in test_data.__iter__():
  197. pass
  198. def test_usps_visualize(plot=False):
  199. """
  200. Visualize USPSDataset results
  201. """
  202. logger.info("Test USPSDataset visualization")
  203. train_data = ds.USPSDataset(DATA_DIR, "train", num_samples=3, shuffle=False)
  204. num_iter = 0
  205. image_list, label_list = [], []
  206. for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
  207. image = item["image"]
  208. label = item["label"]
  209. image_list.append(image)
  210. label_list.append("label {}".format(label))
  211. assert isinstance(image, np.ndarray)
  212. assert image.shape == (16, 16, 1)
  213. assert image.dtype == np.uint8
  214. assert label.dtype == np.uint32
  215. num_iter += 1
  216. assert num_iter == 3
  217. if plot:
  218. visualize_dataset(image_list, label_list)
  219. test_data = ds.USPSDataset(DATA_DIR, "test", num_samples=3, shuffle=False)
  220. num_iter = 0
  221. image_list, label_list = [], []
  222. for item in test_data.create_dict_iterator(num_epochs=1, output_numpy=True):
  223. image = item["image"]
  224. label = item["label"]
  225. image_list.append(image)
  226. label_list.append("label {}".format(label))
  227. assert isinstance(image, np.ndarray)
  228. assert image.shape == (16, 16, 1)
  229. assert image.dtype == np.uint8
  230. assert label.dtype == np.uint32
  231. num_iter += 1
  232. assert num_iter == 3
  233. if plot:
  234. visualize_dataset(image_list, label_list)
  235. def test_usps_usage():
  236. """
  237. Validate USPSDataset image readings
  238. """
  239. logger.info("Test USPSDataset usage flag")
  240. def test_config(usage, path=None):
  241. path = DATA_DIR if path is None else path
  242. try:
  243. data = ds.USPSDataset(path, usage=usage, shuffle=False)
  244. num_rows = 0
  245. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  246. num_rows += 1
  247. except (ValueError, TypeError, RuntimeError) as e:
  248. return str(e)
  249. return num_rows
  250. assert test_config("train") == 3
  251. assert test_config("test") == 3
  252. assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
  253. assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
  254. # change this directory to the folder that contains all USPS files
  255. all_files_path = None
  256. # the following tests on the entire datasets
  257. if all_files_path is not None:
  258. assert test_config("train", all_files_path) == 3
  259. assert test_config("test", all_files_path) == 3
  260. assert ds.USPSDataset(all_files_path, usage="train").get_dataset_size() == 3
  261. assert ds.USPSDataset(all_files_path, usage="test").get_dataset_size() == 3
  262. if __name__ == '__main__':
  263. test_usps_content_check()
  264. test_usps_basic()
  265. test_usps_exception()
  266. test_usps_visualize(plot=True)
  267. test_usps_usage()