You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_flowers102.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test Flowers102 dataset operators
  17. """
  18. import os
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. import pytest
  22. from PIL import Image
  23. from scipy.io import loadmat
  24. import mindspore.dataset as ds
  25. import mindspore.dataset.vision.c_transforms as c_vision
  26. from mindspore import log as logger
  27. DATA_DIR = "../data/dataset/testFlowers102Dataset"
  28. WRONG_DIR = "../data/dataset/testMnistData"
  29. def load_flowers102(path, usage):
  30. """
  31. load Flowers102 data
  32. """
  33. assert usage in ["train", "valid", "test", "all"]
  34. imagelabels = (loadmat(os.path.join(path, "imagelabels.mat"))["labels"][0] - 1).astype(np.uint32)
  35. split = loadmat(os.path.join(path, "setid.mat"))
  36. if usage == 'train':
  37. indices = split["trnid"][0].tolist()
  38. elif usage == 'test':
  39. indices = split["tstid"][0].tolist()
  40. elif usage == 'valid':
  41. indices = split["valid"][0].tolist()
  42. elif usage == 'all':
  43. indices = split["trnid"][0].tolist()
  44. indices += split["tstid"][0].tolist()
  45. indices += split["valid"][0].tolist()
  46. image_paths = [os.path.join(path, "jpg", "image_" + str(index).zfill(5) + ".jpg") for index in indices]
  47. segmentation_paths = [os.path.join(path, "segmim", "segmim_" + str(index).zfill(5) + ".jpg") for index in indices]
  48. images = [np.asarray(Image.open(path).convert("RGB")) for path in image_paths]
  49. segmentations = [np.asarray(Image.open(path).convert("RGB")) for path in segmentation_paths]
  50. labels = [imagelabels[index - 1] for index in indices]
  51. return images, segmentations, labels
  52. def visualize_dataset(images, labels):
  53. """
  54. Helper function to visualize the dataset samples
  55. """
  56. num_samples = len(images)
  57. for i in range(num_samples):
  58. plt.subplot(1, num_samples, i + 1)
  59. plt.imshow(images[i].squeeze())
  60. plt.title(labels[i])
  61. plt.show()
  62. def test_flowers102_content_check():
  63. """
  64. Validate Flowers102Dataset image readings
  65. """
  66. logger.info("Test Flowers102Dataset Op with content check")
  67. all_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="all",
  68. num_samples=6, decode=True, shuffle=False)
  69. images, segmentations, labels = load_flowers102(DATA_DIR, "all")
  70. num_iter = 0
  71. # in this example, each dictionary has keys "image" and "label"
  72. for i, data in enumerate(all_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  73. np.testing.assert_array_equal(data["image"], images[i])
  74. np.testing.assert_array_equal(data["segmentation"], segmentations[i])
  75. np.testing.assert_array_equal(data["label"], labels[i])
  76. num_iter += 1
  77. assert num_iter == 6
  78. train_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="train",
  79. num_samples=2, decode=True, shuffle=False)
  80. images, segmentations, labels = load_flowers102(DATA_DIR, "train")
  81. num_iter = 0
  82. # in this example, each dictionary has keys "image" and "label"
  83. for i, data in enumerate(train_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  84. np.testing.assert_array_equal(data["image"], images[i])
  85. np.testing.assert_array_equal(data["segmentation"], segmentations[i])
  86. np.testing.assert_array_equal(data["label"], labels[i])
  87. num_iter += 1
  88. assert num_iter == 2
  89. test_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="test",
  90. num_samples=2, decode=True, shuffle=False)
  91. images, segmentations, labels = load_flowers102(DATA_DIR, "test")
  92. num_iter = 0
  93. # in this example, each dictionary has keys "image" and "label"
  94. for i, data in enumerate(test_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  95. np.testing.assert_array_equal(data["image"], images[i])
  96. np.testing.assert_array_equal(data["segmentation"], segmentations[i])
  97. np.testing.assert_array_equal(data["label"], labels[i])
  98. num_iter += 1
  99. assert num_iter == 2
  100. val_data = ds.Flowers102Dataset(DATA_DIR, task="Segmentation", usage="valid",
  101. num_samples=2, decode=True, shuffle=False)
  102. images, segmentations, labels = load_flowers102(DATA_DIR, "valid")
  103. num_iter = 0
  104. # in this example, each dictionary has keys "image" and "label"
  105. for i, data in enumerate(val_data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  106. np.testing.assert_array_equal(data["image"], images[i])
  107. np.testing.assert_array_equal(data["segmentation"], segmentations[i])
  108. np.testing.assert_array_equal(data["label"], labels[i])
  109. num_iter += 1
  110. assert num_iter == 2
  111. def test_flowers102_basic():
  112. """
  113. Validate Flowers102Dataset
  114. """
  115. logger.info("Test Flowers102Dataset Op")
  116. # case 1: test decode
  117. all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, shuffle=False)
  118. all_data_1 = all_data.map(operations=[c_vision.Decode()], input_columns=["image"], num_parallel_workers=1)
  119. all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shuffle=False)
  120. num_iter = 0
  121. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1, output_numpy=True),
  122. all_data_2.create_dict_iterator(num_epochs=1, output_numpy=True)):
  123. np.testing.assert_array_equal(item1["label"], item2["label"])
  124. num_iter += 1
  125. assert num_iter == 6
  126. # case 2: test num_samples
  127. all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4)
  128. num_iter = 0
  129. for _ in all_data.create_dict_iterator(num_epochs=1):
  130. num_iter += 1
  131. assert num_iter == 4
  132. # case 3: test repeat
  133. all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_samples=4)
  134. all_data = all_data.repeat(5)
  135. num_iter = 0
  136. for _ in all_data.create_dict_iterator(num_epochs=1):
  137. num_iter += 1
  138. assert num_iter == 20
  139. # case 3: test get_dataset_size, resize and batch
  140. all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4)
  141. all_data = all_data.map(operations=[c_vision.Decode(), c_vision.Resize((224, 224))], input_columns=["image"],
  142. num_parallel_workers=1)
  143. assert all_data.get_dataset_size() == 4
  144. assert all_data.get_batch_size() == 1
  145. all_data = all_data.batch(batch_size=3) # drop_remainder is default to be False
  146. assert all_data.get_batch_size() == 3
  147. assert all_data.get_dataset_size() == 2
  148. num_iter = 0
  149. for _ in all_data.create_dict_iterator(num_epochs=1):
  150. num_iter += 1
  151. assert num_iter == 2
  152. # case 4: test get_class_indexing
  153. all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=False, num_samples=4)
  154. class_indexing = all_data.get_class_indexing()
  155. assert class_indexing["pink primrose"] == 0
  156. assert class_indexing["blackberry lily"] == 101
  157. def test_flowers102_sequential_sampler():
  158. """
  159. Test Flowers102Dataset with SequentialSampler
  160. """
  161. logger.info("Test Flowers102Dataset Op with SequentialSampler")
  162. num_samples = 4
  163. sampler = ds.SequentialSampler(num_samples=num_samples)
  164. all_data_1 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
  165. decode=True, sampler=sampler)
  166. all_data_2 = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all",
  167. decode=True, shuffle=False, num_samples=num_samples)
  168. label_list_1, label_list_2 = [], []
  169. num_iter = 0
  170. for item1, item2 in zip(all_data_1.create_dict_iterator(num_epochs=1),
  171. all_data_2.create_dict_iterator(num_epochs=1)):
  172. label_list_1.append(item1["label"].asnumpy())
  173. label_list_2.append(item2["label"].asnumpy())
  174. num_iter += 1
  175. np.testing.assert_array_equal(label_list_1, label_list_2)
  176. assert num_iter == num_samples
  177. def test_flowers102_exception():
  178. """
  179. Test error cases for Flowers102Dataset
  180. """
  181. logger.info("Test error cases for Flowers102Dataset")
  182. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  183. with pytest.raises(RuntimeError, match=error_msg_1):
  184. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", shuffle=False,
  185. decode=True, sampler=ds.SequentialSampler(1))
  186. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  187. with pytest.raises(RuntimeError, match=error_msg_2):
  188. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", sampler=ds.SequentialSampler(1),
  189. decode=True, num_shards=2, shard_id=0)
  190. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  191. with pytest.raises(RuntimeError, match=error_msg_3):
  192. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=10)
  193. error_msg_4 = "shard_id is specified but num_shards is not"
  194. with pytest.raises(RuntimeError, match=error_msg_4):
  195. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, shard_id=0)
  196. error_msg_5 = "Input shard_id is not within the required interval"
  197. with pytest.raises(ValueError, match=error_msg_5):
  198. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=-1)
  199. with pytest.raises(ValueError, match=error_msg_5):
  200. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=5, shard_id=5)
  201. with pytest.raises(ValueError, match=error_msg_5):
  202. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id=5)
  203. error_msg_6 = "num_parallel_workers exceeds"
  204. with pytest.raises(ValueError, match=error_msg_6):
  205. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
  206. shuffle=False, num_parallel_workers=0)
  207. with pytest.raises(ValueError, match=error_msg_6):
  208. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
  209. shuffle=False, num_parallel_workers=256)
  210. with pytest.raises(ValueError, match=error_msg_6):
  211. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True,
  212. shuffle=False, num_parallel_workers=-2)
  213. error_msg_7 = "Argument shard_id"
  214. with pytest.raises(TypeError, match=error_msg_7):
  215. ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=True, num_shards=2, shard_id="0")
  216. error_msg_8 = "does not exist or is not a directory or permission denied!"
  217. with pytest.raises(ValueError, match=error_msg_8):
  218. all_data = ds.Flowers102Dataset(WRONG_DIR, task="Classification", usage="all", decode=True)
  219. for _ in all_data.create_dict_iterator(num_epochs=1):
  220. pass
  221. error_msg_9 = "is not of type"
  222. with pytest.raises(TypeError, match=error_msg_9):
  223. all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", decode=123)
  224. for _ in all_data.create_dict_iterator(num_epochs=1):
  225. pass
  226. def test_flowers102_visualize(plot=False):
  227. """
  228. Visualize Flowers102Dataset results
  229. """
  230. logger.info("Test Flowers102Dataset visualization")
  231. all_data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage="all", num_samples=4,
  232. decode=True, shuffle=False)
  233. num_iter = 0
  234. image_list, label_list = [], []
  235. for item in all_data.create_dict_iterator(num_epochs=1, output_numpy=True):
  236. image = item["image"]
  237. label = item["label"]
  238. image_list.append(image)
  239. label_list.append("label {}".format(label))
  240. assert isinstance(image, np.ndarray)
  241. assert len(image.shape) == 3
  242. assert image.shape[-1] == 3
  243. assert image.dtype == np.uint8
  244. assert label.dtype == np.uint32
  245. num_iter += 1
  246. assert num_iter == 4
  247. if plot:
  248. visualize_dataset(image_list, label_list)
  249. def test_flowers102_usage():
  250. """
  251. Validate Flowers102Dataset usage
  252. """
  253. logger.info("Test Flowers102Dataset usage flag")
  254. def test_config(usage):
  255. try:
  256. data = ds.Flowers102Dataset(DATA_DIR, task="Classification", usage=usage, decode=True, shuffle=False)
  257. num_rows = 0
  258. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  259. num_rows += 1
  260. except (ValueError, TypeError, RuntimeError) as e:
  261. return str(e)
  262. return num_rows
  263. assert test_config("all") == 6
  264. assert test_config("train") == 2
  265. assert test_config("test") == 2
  266. assert test_config("valid") == 2
  267. assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid")
  268. assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
  269. def test_flowers102_task():
  270. """
  271. Validate Flowers102Dataset task
  272. """
  273. logger.info("Test Flowers102Dataset task flag")
  274. def test_config(task):
  275. try:
  276. data = ds.Flowers102Dataset(DATA_DIR, task=task, usage="all", decode=True, shuffle=False)
  277. num_rows = 0
  278. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  279. num_rows += 1
  280. except (ValueError, TypeError, RuntimeError) as e:
  281. return str(e)
  282. return num_rows
  283. assert test_config("Classification") == 6
  284. assert test_config("Segmentation") == 6
  285. assert "Input task is not within the valid set of ['Classification', 'Segmentation']" in test_config("invalid")
  286. assert "Argument task with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
  287. if __name__ == '__main__':
  288. test_flowers102_content_check()
  289. test_flowers102_basic()
  290. test_flowers102_sequential_sampler()
  291. test_flowers102_exception()
  292. test_flowers102_visualize(plot=True)
  293. test_flowers102_usage()
  294. test_flowers102_task()