# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Places365 dataset operators
"""
import os
import matplotlib.pyplot as plt
import numpy as np
import pytest
from PIL import Image
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = "../data/dataset/testPlaces365Data"
  26. def load_places365(path):
  27. """
  28. Feature: load_places365.
  29. Description: load places365.
  30. Expectation: get data of places365 dataset.
  31. """
  32. images_path = os.path.realpath(os.path.join(path, 'val_256'))
  33. labels_path = os.path.realpath(os.path.join(path, 'places365_val.txt'))
  34. images = []
  35. labels = []
  36. with open(labels_path, 'r') as f:
  37. for line in f.readlines():
  38. file_path, label = line.split()
  39. image = np.array(Image.open(images_path + file_path))
  40. label = int(label)
  41. images.append(image)
  42. labels.append(label)
  43. return images, labels
  44. def visualize_dataset(images, labels):
  45. """
  46. Feature: visualize_dataset.
  47. Description: visualize places365 dataset.
  48. Expectation: plot images.
  49. """
  50. num_samples = len(images)
  51. for i in range(num_samples):
  52. plt.subplot(1, num_samples, i + 1)
  53. plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
  54. plt.title(labels[i])
  55. plt.show()
  56. def test_places365_content_check():
  57. """
  58. Feature: test_places365_content_check.
  59. Description: validate Places365Dataset image readings.
  60. Expectation: get correct number of data and correct content.
  61. """
  62. logger.info("Test Places365Dataset Op with content check")
  63. sampler = ds.SequentialSampler(num_samples=4)
  64. data1 = ds.Places365Dataset(dataset_dir=DATA_DIR, usage='val', small=True, decode=True, sampler=sampler)
  65. _, labels = load_places365(DATA_DIR)
  66. num_iter = 0
  67. # in this example, each dictionary has keys "image" and "label"
  68. image_list, label_list = [], []
  69. for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  70. image_list.append(data["image"])
  71. label_list.append("label {}".format(data["label"]))
  72. # due to the precision problem, the following two doesn't total equal.
  73. # np.testing.assert_array_equal(data["image"], images[i])
  74. np.testing.assert_array_equal(data["label"], labels[i])
  75. num_iter += 1
  76. assert num_iter == 4
  77. def test_places365_basic():
  78. """
  79. Feature: test_places365_basic.
  80. Description: test basic usage of Places365Dataset.
  81. Expectation: get correct number of data.
  82. """
  83. logger.info("Test places365Dataset Op")
  84. # case 1: test loading whole dataset
  85. data1 = ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True)
  86. num_iter1 = 0
  87. for _ in data1.create_dict_iterator(num_epochs=1):
  88. num_iter1 += 1
  89. assert num_iter1 == 4
  90. # case 2: test num_samples
  91. data2 = ds.Places365Dataset(DATA_DIR, usage='train-standard', small=True, decode=True, num_samples=4)
  92. num_iter2 = 0
  93. for _ in data2.create_dict_iterator(num_epochs=1):
  94. num_iter2 += 1
  95. assert num_iter2 == 4
  96. # case 3: test repeat
  97. data3 = ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, num_samples=4)
  98. data3 = data3.repeat(5)
  99. num_iter3 = 0
  100. for _ in data3.create_dict_iterator(num_epochs=1):
  101. num_iter3 += 1
  102. assert num_iter3 == 20
  103. # case 4: test batch with drop_remainder=False
  104. data4 = ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, num_samples=4)
  105. assert data4.get_dataset_size() == 4
  106. assert data4.get_batch_size() == 1
  107. data4 = data4.batch(batch_size=2) # drop_remainder is default to be False
  108. assert data4.get_dataset_size() == 2
  109. assert data4.get_batch_size() == 2
  110. num_iter4 = 0
  111. for _ in data4.create_dict_iterator(num_epochs=1):
  112. num_iter4 += 1
  113. assert num_iter4 == 2
  114. # case 5: test batch with drop_remainder=True
  115. data5 = ds.Places365Dataset(DATA_DIR, usage='train-standard', small=True, decode=True, num_samples=4)
  116. assert data5.get_dataset_size() == 4
  117. assert data5.get_batch_size() == 1
  118. data5 = data5.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped
  119. assert data5.get_dataset_size() == 1
  120. assert data5.get_batch_size() == 3
  121. num_iter5 = 0
  122. for _ in data5.create_dict_iterator(num_epochs=1):
  123. num_iter5 += 1
  124. assert num_iter5 == 1
  125. def test_places365_pk_sampler():
  126. """
  127. Feature: test_places365_pk_sampler.
  128. Description: test usage of Places365Dataset with PKSampler.
  129. Expectation: get correct number of data.
  130. """
  131. logger.info("Test Places365Dataset Op with PKSampler")
  132. sampler = ds.PKSampler(1)
  133. data = ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, sampler=sampler)
  134. num_iter = 0
  135. golden = [0, 1]
  136. label_list = []
  137. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  138. label_list.append(item["label"])
  139. num_iter += 1
  140. np.testing.assert_array_equal(golden, label_list)
  141. assert num_iter == 2
  142. def test_places365_sequential_sampler():
  143. """
  144. Feature: test_places365_sequential_sampler.
  145. Description: test usage of Places365Dataset with SequentialSampler.
  146. Expectation: get correct number of data.
  147. """
  148. logger.info("Test Places365Dataset Op with SequentialSampler")
  149. num_samples = 4
  150. sampler = ds.SequentialSampler(num_samples=num_samples)
  151. data1 = ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, sampler=sampler)
  152. data2 = ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, shuffle=False, num_samples=num_samples)
  153. label_list1, label_list2 = [], []
  154. num_iter = 0
  155. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
  156. label_list1.append(item1["label"].asnumpy())
  157. label_list2.append(item2["label"].asnumpy())
  158. num_iter += 1
  159. np.testing.assert_array_equal(label_list1, label_list2)
  160. assert num_iter == num_samples
  161. def test_places365_exception():
  162. """
  163. Feature: test_places365_exception.
  164. Description: test error cases for Places365Dataset.
  165. Expectation: raise exception.
  166. """
  167. logger.info("Test error cases for Places365Dataset")
  168. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  169. with pytest.raises(RuntimeError, match=error_msg_1):
  170. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, shuffle=False, sampler=ds.PKSampler(3))
  171. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  172. with pytest.raises(RuntimeError, match=error_msg_2):
  173. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True,
  174. sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
  175. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  176. with pytest.raises(RuntimeError, match=error_msg_3):
  177. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, num_shards=4)
  178. error_msg_4 = "shard_id is specified but num_shards is not"
  179. with pytest.raises(RuntimeError, match=error_msg_4):
  180. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, shard_id=0)
  181. error_msg_5 = "Input shard_id is not within the required interval"
  182. with pytest.raises(ValueError, match=error_msg_5):
  183. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, num_shards=2, shard_id=-1)
  184. with pytest.raises(ValueError, match=error_msg_5):
  185. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, num_shards=2, shard_id=2)
  186. with pytest.raises(ValueError, match=error_msg_5):
  187. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, num_shards=2, shard_id=5)
  188. error_msg_6 = "num_parallel_workers exceeds"
  189. with pytest.raises(ValueError, match=error_msg_6):
  190. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, shuffle=False, num_parallel_workers=0)
  191. with pytest.raises(ValueError, match=error_msg_6):
  192. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, shuffle=False, num_parallel_workers=256)
  193. with pytest.raises(ValueError, match=error_msg_6):
  194. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, shuffle=False, num_parallel_workers=-2)
  195. error_msg_7 = "Argument shard_id"
  196. with pytest.raises(TypeError, match=error_msg_7):
  197. ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, num_shards=2, shard_id="0")
  198. def test_places365_visualize(plot=False):
  199. """
  200. Feature: test_places365_visualize.
  201. Description: visualize Places365Dataset results.
  202. Expectation: get correct number of data and plot them.
  203. """
  204. logger.info("Test Places365Dataset visualization")
  205. data1 = ds.Places365Dataset(DATA_DIR, usage='val', small=True, decode=True, num_samples=4, shuffle=False)
  206. num_iter = 0
  207. image_list, label_list = [], []
  208. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  209. image = item["image"]
  210. label = item["label"]
  211. image_list.append(image)
  212. label_list.append("label {}".format(label))
  213. assert isinstance(image, np.ndarray)
  214. assert image.shape == (256, 256, 3)
  215. assert image.dtype == np.uint8
  216. assert label.dtype == np.uint32
  217. num_iter += 1
  218. assert num_iter == 4
  219. if plot:
  220. visualize_dataset(image_list, label_list)
  221. def test_places365_usage():
  222. """
  223. Feature: test_places365_usage.
  224. Description: validate Places365Dataset image readings.
  225. Expectation: get correct number of data.
  226. """
  227. logger.info("Test Places365Dataset usage flag")
  228. def test_config(usage, places365_path=None):
  229. if places365_path is None:
  230. places365_path = DATA_DIR
  231. try:
  232. data = ds.Places365Dataset(places365_path, usage=usage, small=True, decode=True, shuffle=False)
  233. num_rows = 0
  234. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  235. num_rows += 1
  236. except (ValueError, TypeError, RuntimeError) as e:
  237. print(str(e))
  238. return str(e)
  239. return num_rows
  240. assert test_config("val") == 4
  241. assert "usage is not within the valid set of ['train-standard', 'train-challenge', 'val']" in test_config("invalid")
  242. assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])
  243. # change this directory to the folder that contains all places365 files
  244. train_standard_files_path = DATA_DIR
  245. # the following tests on the entire datasets
  246. if train_standard_files_path is not None:
  247. assert test_config("train-standard", train_standard_files_path) == 4
  248. assert test_config("val", train_standard_files_path) == 4
  249. # change this directory to the folder that contains all places365 files
  250. train_challenge_files_path = DATA_DIR
  251. # the following tests on the entire datasets
  252. if train_challenge_files_path is not None:
  253. assert test_config("train-challenge", train_challenge_files_path) == 4
  254. assert test_config("val", train_standard_files_path) == 4
if __name__ == '__main__':
    # Run the whole Places365 test suite as a stand-alone script
    # (plotting enabled for the visualization test).
    test_places365_content_check()
    test_places365_basic()
    test_places365_pk_sampler()
    test_places365_sequential_sampler()
    test_places365_exception()
    test_places365_visualize(plot=True)
    test_places365_usage()