You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_photo_tour.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test PhotoTour dataset operator
  17. """
  18. import os
  19. import matplotlib.pyplot as plt
  20. import numpy as np
  21. import pytest
  22. from PIL import Image
  23. import mindspore.dataset as ds
  24. from mindspore import log as logger
  25. DATA_DIR = "../data/dataset/testPhotoTourData"
  26. NAME = 'liberty'
  27. LEN = 100
  28. def load_photo_tour_dataset(path, name):
  29. """
  30. Feature: load_photo_tour_dataset.
  31. Description: load photo tour.
  32. Expectation: get data of photo tour dataset.
  33. """
  34. def pil2array(img: Image.Image):
  35. """
  36. Convert PIL image type to numpy 2D array
  37. """
  38. return np.array(img.getdata(), dtype=np.uint8).reshape((64, 64, 1))
  39. def find_files(data_dir: str, image_ext_: str):
  40. """
  41. Return a list with the file names of the images containing the patches
  42. """
  43. files = []
  44. # find those files with the specified extension
  45. for file_dir in os.listdir(data_dir):
  46. if file_dir.endswith(image_ext_):
  47. files.append(os.path.join(data_dir, file_dir))
  48. return sorted(files) # sort files in ascend order to keep relations
  49. patches = []
  50. list_files = find_files(os.path.realpath(os.path.join(path, name)), 'bmp')
  51. idx = 0
  52. for fpath in list_files:
  53. img = Image.open(fpath)
  54. for y in range(0, 1024, 64):
  55. for x in range(0, 1024, 64):
  56. patch = img.crop((x, y, x + 64, y + 64))
  57. patches.append(pil2array(patch))
  58. idx += 1
  59. if idx > LEN:
  60. break
  61. if idx > LEN:
  62. break
  63. matches_path = os.path.join(os.path.realpath(os.path.join(path, name)), 'm50_100000_100000_0.txt')
  64. matches = []
  65. with open(matches_path, 'r') as f:
  66. for line in f.readlines():
  67. line_split = line.split()
  68. matches.append([int(line_split[0]), int(line_split[3]),
  69. int(line_split[1] == line_split[4])])
  70. return patches, matches
  71. def visualize_dataset(images1, images2, matches):
  72. """
  73. Feature: visualize_dataset.
  74. Description: visualize photo tour dataset.
  75. Expectation: plot images.
  76. """
  77. num_samples = len(images1)
  78. for i in range(num_samples):
  79. plt.subplot(1, num_samples, i + 1)
  80. plt.imshow(images1[i].squeeze(), cmap=plt.cm.gray)
  81. plt.title(matches[i])
  82. num_samples = len(images2)
  83. for i in range(num_samples):
  84. plt.subplot(2, num_samples, i + 1)
  85. plt.imshow(images2[i].squeeze(), cmap=plt.cm.gray)
  86. plt.title(matches[i])
  87. plt.show()
  88. def test_photo_tour_content_check():
  89. """
  90. Feature: test_photo_tour_content_check.
  91. Description: validate PhotoTourDataset image readings.
  92. Expectation: get correct number of data and correct content.
  93. """
  94. logger.info("Test PhotoTourDataset Op with content check")
  95. data1 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10, shuffle=False)
  96. images, matches = load_photo_tour_dataset(DATA_DIR, NAME)
  97. num_iter = 0
  98. # in this example, each dictionary has keys "image1" "image2" and "matches"
  99. for i, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  100. np.testing.assert_array_equal(data["image1"], images[matches[i][0]])
  101. np.testing.assert_array_equal(data["image2"], images[matches[i][1]])
  102. np.testing.assert_array_equal(data["matches"], matches[i][2])
  103. num_iter += 1
  104. assert num_iter == 10
  105. def test_photo_tour_basic():
  106. """
  107. Feature: test_photo_tour_basic.
  108. Description: test basic usage of PhotoTourDataset.
  109. Expectation: get correct number of data.
  110. """
  111. logger.info("Test PhotoTourDataset Op")
  112. # case 1: test loading whole dataset
  113. data1 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test')
  114. num_iter1 = 0
  115. for _ in data1.create_dict_iterator(num_epochs=1):
  116. num_iter1 += 1
  117. assert num_iter1 == 16
  118. # case 2: test num_samples
  119. data2 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10)
  120. num_iter2 = 0
  121. for _ in data2.create_dict_iterator(num_epochs=1):
  122. num_iter2 += 1
  123. assert num_iter2 == 10
  124. # case 3: test repeat
  125. data3 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=5)
  126. data3 = data3.repeat(5)
  127. num_iter3 = 0
  128. for _ in data3.create_dict_iterator(num_epochs=1):
  129. num_iter3 += 1
  130. assert num_iter3 == 25
  131. # case 4: test batch with drop_remainder=False
  132. data4 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10)
  133. assert data4.get_dataset_size() == 10
  134. assert data4.get_batch_size() == 1
  135. data4 = data4.batch(batch_size=7) # drop_remainder is default to be False
  136. assert data4.get_dataset_size() == 2
  137. assert data4.get_batch_size() == 7
  138. num_iter4 = 0
  139. for _ in data4.create_dict_iterator(num_epochs=1):
  140. num_iter4 += 1
  141. assert num_iter4 == 2
  142. # case 5: test batch with drop_remainder=True
  143. data5 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10)
  144. assert data5.get_dataset_size() == 10
  145. assert data5.get_batch_size() == 1
  146. data5 = data5.batch(batch_size=7, drop_remainder=True) # the rest of incomplete batch will be dropped
  147. assert data5.get_dataset_size() == 1
  148. assert data5.get_batch_size() == 7
  149. num_iter5 = 0
  150. for _ in data5.create_dict_iterator(num_epochs=1):
  151. num_iter5 += 1
  152. assert num_iter5 == 1
  153. # case 6: test get_col_names
  154. data6 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10)
  155. assert data6.get_col_names() == ['image1', 'image2', 'matches']
  156. def test_photo_tour_pk_sampler():
  157. """
  158. Feature: test_photo_tour_pk_sampler.
  159. Description: test usage of PhotoTourDataset with PKSampler.
  160. Expectation: get correct number of data.
  161. """
  162. logger.info("Test PhotoTourDataset Op with PKSampler")
  163. golden = [0, 0, 0, 1, 1, 1]
  164. sampler = ds.PKSampler(3)
  165. data = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', sampler=sampler)
  166. num_iter = 0
  167. matches_list = []
  168. for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  169. matches_list.append(item["matches"])
  170. num_iter += 1
  171. np.testing.assert_array_equal(golden, matches_list)
  172. assert num_iter == 6
  173. def test_photo_tour_sequential_sampler():
  174. """
  175. Feature: test_photo_tour_sequential_sampler.
  176. Description: test usage of PhotoTourDataset with SequentialSampler.
  177. Expectation: get correct number of data.
  178. """
  179. logger.info("Test PhotoTourDataset Op with SequentialSampler")
  180. num_samples = 5
  181. sampler = ds.SequentialSampler(num_samples=num_samples)
  182. data1 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', sampler=sampler)
  183. data2 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, num_samples=num_samples)
  184. matches_list1, matches_list2 = [], []
  185. num_iter = 0
  186. for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), data2.create_dict_iterator(num_epochs=1)):
  187. matches_list1.append(item1["matches"].asnumpy())
  188. matches_list2.append(item2["matches"].asnumpy())
  189. num_iter += 1
  190. np.testing.assert_array_equal(matches_list1, matches_list2)
  191. assert num_iter == num_samples
  192. def test_photo_tour_exception():
  193. """
  194. Feature: test_photo_tour_exception.
  195. Description: test error cases for PhotoTourDataset.
  196. Expectation: raise exception.
  197. """
  198. logger.info("Test error cases for PhotoTourDataset")
  199. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  200. with pytest.raises(RuntimeError, match=error_msg_1):
  201. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, sampler=ds.PKSampler(3))
  202. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  203. with pytest.raises(RuntimeError, match=error_msg_2):
  204. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
  205. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  206. with pytest.raises(RuntimeError, match=error_msg_3):
  207. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=10)
  208. error_msg_4 = "shard_id is specified but num_shards is not"
  209. with pytest.raises(RuntimeError, match=error_msg_4):
  210. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shard_id=0)
  211. error_msg_5 = "Input shard_id is not within the required interval"
  212. with pytest.raises(ValueError, match=error_msg_5):
  213. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=5, shard_id=-1)
  214. with pytest.raises(ValueError, match=error_msg_5):
  215. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=5, shard_id=5)
  216. with pytest.raises(ValueError, match=error_msg_5):
  217. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=2, shard_id=5)
  218. error_msg_6 = "num_parallel_workers exceeds"
  219. with pytest.raises(ValueError, match=error_msg_6):
  220. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, num_parallel_workers=0)
  221. with pytest.raises(ValueError, match=error_msg_6):
  222. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, num_parallel_workers=256)
  223. with pytest.raises(ValueError, match=error_msg_6):
  224. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', shuffle=False, num_parallel_workers=-2)
  225. error_msg_7 = "Argument shard_id"
  226. with pytest.raises(TypeError, match=error_msg_7):
  227. ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_shards=2, shard_id="0")
  228. def test_photo_tour_visualize(plot=False):
  229. """
  230. Feature: test_photo_tour_visualize.
  231. Description: visualize PhotoTourDataset results.
  232. Expectation: get correct number of data and plot them.
  233. """
  234. logger.info("Test PhotoTourDataset visualization")
  235. data1 = ds.PhotoTourDataset(DATA_DIR, NAME, 'test', num_samples=10, shuffle=False)
  236. num_iter = 0
  237. image_list1, image_list2, matches_list = [], [], []
  238. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  239. image1 = item["image1"]
  240. image2 = item["image2"]
  241. matches = item["matches"]
  242. image_list1.append(image1)
  243. image_list2.append(image2)
  244. matches_list.append("matches {}".format(matches))
  245. assert isinstance(image1, np.ndarray)
  246. assert isinstance(image2, np.ndarray)
  247. assert image1.shape == (64, 64, 1)
  248. assert image1.dtype == np.uint8
  249. assert image2.shape == (64, 64, 1)
  250. assert image2.dtype == np.uint8
  251. assert matches.dtype == np.uint32
  252. num_iter += 1
  253. assert num_iter == 10
  254. if plot:
  255. visualize_dataset(image_list1, image_list2, matches_list)
  256. def test_photo_tour_usage():
  257. """
  258. Feature: test_photo_tour_usage.
  259. Description: validate PhotoTourDataset image readings.
  260. Expectation: get correct number of data.
  261. """
  262. logger.info("Test PhotoTourDataset usage flag")
  263. def test_config(photo_tour_path, name, usage):
  264. try:
  265. data = ds.PhotoTourDataset(photo_tour_path, name, usage, shuffle=False)
  266. num_rows = 0
  267. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  268. num_rows += 1
  269. except (ValueError, TypeError, RuntimeError) as e:
  270. return str(e)
  271. return num_rows
  272. assert test_config(DATA_DIR, NAME, "test") == 16
  273. assert test_config(DATA_DIR, NAME, "train") == LEN
  274. assert "usage is not within the valid set of ['train', 'test']" in test_config(DATA_DIR, NAME, "invalid")
  275. assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(DATA_DIR, NAME, ["list"])
  276. if __name__ == '__main__':
  277. test_photo_tour_content_check()
  278. test_photo_tour_basic()
  279. test_photo_tour_pk_sampler()
  280. test_photo_tour_sequential_sampler()
  281. test_photo_tour_exception()
  282. test_photo_tour_visualize(plot=True)
  283. test_photo_tour_usage()