You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_lfw.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import re
  16. import pytest
  17. import mindspore.dataset as ds
  18. from mindspore import log as logger
# Relative path to the LFW test fixture directory; it is resolved against the
# current working directory, so the tests must be run from the directory that
# contains "../data/dataset/testLFW".
DATA_DIR = "../data/dataset/testLFW"
  20. def test_lfw_basic():
  21. """
  22. Feature: LFW
  23. Description: test basic usage of LFW
  24. Expectation: the dataset is as expected
  25. """
  26. logger.info("Test Case basic")
  27. # define parameters.
  28. repeat_count = 2
  29. # apply dataset operations.
  30. data1 = ds.LFWDataset(DATA_DIR)
  31. data1 = data1.repeat(repeat_count)
  32. num_iter = 0
  33. # each data is a dictionary.
  34. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  35. # in this example, each dictionary has keys "image" and "label".
  36. logger.info("image is {}".format(item["image"]))
  37. logger.info("label is {}".format(item["label"]))
  38. num_iter += 1
  39. logger.info("Number of data in data1: {}".format(num_iter))
  40. assert num_iter == 8
  41. def test_lfw_task():
  42. """
  43. Feature: LFW
  44. Description: test basic usage of LFW
  45. Expectation: the dataset is as expected
  46. """
  47. logger.info("Test Case basic")
  48. # define parameters.
  49. repeat_count = 2
  50. # apply dataset operations.
  51. data1 = ds.LFWDataset(DATA_DIR, task="pairs", usage="all")
  52. data1 = data1.repeat(repeat_count)
  53. num_iter = 0
  54. # each data is a dictionary.
  55. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  56. # in this example, each dictionary has keys "image" and "label".
  57. logger.info("image1 is {}".format(item["image1"]))
  58. logger.info("image2 is {}".format(item["image2"]))
  59. logger.info("label is {}".format(item["label"]))
  60. num_iter += 1
  61. logger.info("Number of data in data1: {}".format(num_iter))
  62. assert num_iter == 16
  63. def test_lfw_usage():
  64. """
  65. Feature: LFW
  66. Description: test basic usage of LFW
  67. Expectation: the dataset is as expected
  68. """
  69. logger.info("Test Case basic")
  70. # define parameters.
  71. repeat_count = 2
  72. # apply dataset operations.
  73. data1 = ds.LFWDataset(DATA_DIR, usage="test")
  74. data1 = data1.repeat(repeat_count)
  75. num_iter = 0
  76. # each data is a dictionary.
  77. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  78. # in this example, each dictionary has keys "image" and "label".
  79. logger.info("image is {}".format(item["image"]))
  80. logger.info("label is {}".format(item["label"]))
  81. num_iter += 1
  82. logger.info("Number of data in data1: {}".format(num_iter))
  83. assert num_iter == 6
  84. def test_lfw_image_set():
  85. """
  86. Feature: LFW
  87. Description: test basic usage of LFW
  88. Expectation: the dataset is as expected
  89. """
  90. logger.info("Test Case basic")
  91. # define parameters.
  92. repeat_count = 2
  93. # apply dataset operations.
  94. data1 = ds.LFWDataset(DATA_DIR, image_set="funneled")
  95. data1 = data1.repeat(repeat_count)
  96. num_iter = 0
  97. # each data is a dictionary.
  98. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  99. # in this example, each dictionary has keys "image" and "label".
  100. logger.info("image is {}".format(item["image"]))
  101. logger.info("label is {}".format(item["label"]))
  102. num_iter += 1
  103. logger.info("Number of data in data1: {}".format(num_iter))
  104. assert num_iter == 8
  105. def test_lfw_num_samples():
  106. """
  107. Feature: LFW
  108. Description: test basic usage of LFW
  109. Expectation: the dataset is as expected
  110. """
  111. logger.info("Test Case numSamples")
  112. # define parameters.
  113. repeat_count = 2
  114. # apply dataset operations.
  115. data1 = ds.LFWDataset(DATA_DIR, num_samples=4, num_parallel_workers=2)
  116. data1 = data1.repeat(repeat_count)
  117. num_iter = 0
  118. # each data is a dictionary.
  119. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  120. # in this example, each dictionary has keys "image" and "label".
  121. logger.info("image is {}".format(item["image"]))
  122. logger.info("label is {}".format(item["label"]))
  123. num_iter += 1
  124. logger.info("Number of data in data1: {}".format(num_iter))
  125. assert num_iter == 8
  126. random_sampler = ds.RandomSampler(num_samples=2, replacement=True)
  127. data1 = ds.LFWDataset(DATA_DIR, num_parallel_workers=2,
  128. sampler=random_sampler)
  129. num_iter = 0
  130. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  131. num_iter += 1
  132. assert num_iter == 2
  133. random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
  134. data1 = ds.LFWDataset(DATA_DIR, num_parallel_workers=2,
  135. sampler=random_sampler)
  136. num_iter = 0
  137. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  138. num_iter += 1
  139. assert num_iter == 3
  140. def test_lfw_num_shards():
  141. """
  142. Feature: LFW
  143. Description: test basic usage of LFW
  144. Expectation: the dataset is as expected
  145. """
  146. logger.info("Test Case numShards")
  147. # define parameters.
  148. repeat_count = 2
  149. # apply dataset operations.
  150. data1 = ds.LFWDataset(DATA_DIR, num_shards=5, shard_id=1)
  151. data1 = data1.repeat(repeat_count)
  152. num_iter = 0
  153. # each data is a dictionary.
  154. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  155. # in this example, each dictionary has keys "image" and "label".
  156. logger.info("image is {}".format(item["image"]))
  157. logger.info("label is {}".format(item["label"]))
  158. num_iter += 1
  159. logger.info("Number of data in data1: {}".format(num_iter))
  160. assert num_iter == 2
  161. def test_lfw_shard_id():
  162. """
  163. Feature: LFW
  164. Description: test basic usage of LFW
  165. Expectation: the dataset is as expected
  166. """
  167. logger.info("Test Case withShardID")
  168. # define parameters.
  169. repeat_count = 2
  170. # apply dataset operations.
  171. data1 = ds.LFWDataset(DATA_DIR, num_shards=4, shard_id=1)
  172. data1 = data1.repeat(repeat_count)
  173. num_iter = 0
  174. # each data is a dictionary.
  175. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  176. # in this example, each dictionary has keys "image" and "label".
  177. logger.info("image is {}".format(item["image"]))
  178. logger.info("label is {}".format(item["label"]))
  179. num_iter += 1
  180. logger.info("Number of data in data1: {}".format(num_iter))
  181. assert num_iter == 2
  182. def test_lfw_no_shuffle():
  183. """
  184. Feature: LFW
  185. Description: test dataset of LFW
  186. Expectation: the dataset is as expected
  187. """
  188. logger.info("Test Case noShuffle")
  189. # define parameters.
  190. repeat_count = 2
  191. # apply dataset operations.
  192. data1 = ds.LFWDataset(DATA_DIR, shuffle=False)
  193. data1 = data1.repeat(repeat_count)
  194. num_iter = 0
  195. # each data is a dictionary.
  196. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  197. # in this example, each dictionary has keys "image" and "label".
  198. logger.info("image is {}".format(item["image"]))
  199. logger.info("label is {}".format(item["label"]))
  200. num_iter += 1
  201. logger.info("Number of data in data1: {}".format(num_iter))
  202. assert num_iter == 8
  203. def test_lfw_decode():
  204. """
  205. Feature: LFW
  206. Description: test basic usage of LFW
  207. Expectation: the dataset is as expected
  208. """
  209. logger.info("Test Case decode")
  210. # define parameters.
  211. repeat_count = 2
  212. # apply dataset operations.
  213. data1 = ds.LFWDataset(DATA_DIR, decode=True)
  214. data1 = data1.repeat(repeat_count)
  215. num_iter = 0
  216. # each data is a dictionary.
  217. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  218. # in this example, each dictionary has keys "image" and "label".
  219. logger.info("image is {}".format(item["image"]))
  220. logger.info("label is {}".format(item["label"]))
  221. num_iter += 1
  222. logger.info("Number of data in data1: {}".format(num_iter))
  223. assert num_iter == 8
  224. def test_sequential_sampler():
  225. """
  226. Feature: LFW
  227. Description: test basic usage of LFW
  228. Expectation: the dataset is as expected
  229. """
  230. logger.info("Test Case SequentialSampler")
  231. # define parameters.
  232. repeat_count = 2
  233. # apply dataset operations.
  234. sampler = ds.SequentialSampler(num_samples=3)
  235. data1 = ds.LFWDataset(DATA_DIR, sampler=sampler)
  236. data1 = data1.repeat(repeat_count)
  237. result = []
  238. num_iter = 0
  239. # each data is a dictionary.
  240. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  241. # in this example, each dictionary has keys "image" and "label".
  242. result.append(item["label"])
  243. num_iter += 1
  244. assert num_iter == 6
  245. logger.info("Result: {}".format(result))
  246. def test_random_and_sequentialchained_sampler():
  247. """
  248. Feature: LFW
  249. Description: test basic usage of LFW
  250. Expectation: the dataset is as expected
  251. """
  252. logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
  253. # Create chained sampler, random and sequential.
  254. sampler = ds.RandomSampler()
  255. child_sampler = ds.SequentialSampler()
  256. sampler.add_child(child_sampler)
  257. # Create LFWDataset with sampler.
  258. data1 = ds.LFWDataset(DATA_DIR, sampler=sampler)
  259. data1 = data1.repeat(count=3)
  260. # Verify dataset size.
  261. data1_size = data1.get_dataset_size()
  262. logger.info("dataset size is: {}".format(data1_size))
  263. assert data1_size == 12
  264. # Verify number of iterations.
  265. num_iter = 0
  266. # each data is a dictionary.
  267. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  268. # in this example, each dictionary has keys "image" and "label".
  269. logger.info("image is {}".format(item["image"]))
  270. logger.info("label is {}".format(item["label"]))
  271. num_iter += 1
  272. logger.info("Number of data in data1: {}".format(num_iter))
  273. assert num_iter == 12
  274. def test_lfw_rename():
  275. """
  276. Feature: LFW
  277. Description: test basic usage of LFW
  278. Expectation: the dataset is as expected
  279. """
  280. logger.info("Test Case rename")
  281. # define parameters.
  282. repeat_count = 2
  283. # apply dataset operations.
  284. data1 = ds.LFWDataset(DATA_DIR, num_samples=4)
  285. data1 = data1.repeat(repeat_count)
  286. num_iter = 0
  287. # each data is a dictionary.
  288. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  289. # in this example, each dictionary has keys "image" and "label".
  290. logger.info("image is {}".format(item["image"]))
  291. logger.info("label is {}".format(item["label"]))
  292. num_iter += 1
  293. logger.info("Number of data in data1: {}".format(num_iter))
  294. assert num_iter == 8
  295. data1 = data1.rename(input_columns=["image"], output_columns="image2")
  296. num_iter = 0
  297. # each data is a dictionary.
  298. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  299. # in this example, each dictionary has keys "image" and "label".
  300. logger.info("image is {}".format(item["image2"]))
  301. logger.info("label is {}".format(item["label"]))
  302. num_iter += 1
  303. logger.info("Number of data in data1: {}".format(num_iter))
  304. assert num_iter == 8
  305. def test_lfw_zip():
  306. """
  307. Feature: LFW
  308. Description: test basic usage of LFW
  309. Expectation: the dataset is as expected
  310. """
  311. logger.info("Test Case zip")
  312. # define parameters.
  313. repeat_count = 2
  314. # apply dataset operations.
  315. data1 = ds.LFWDataset(DATA_DIR, num_samples=3)
  316. data2 = ds.LFWDataset(DATA_DIR, num_samples=3)
  317. data1 = data1.repeat(repeat_count)
  318. # rename dataset2 for no conflict.
  319. data2 = data2.rename(input_columns=["image", "label"],
  320. output_columns=["image1", "label1"])
  321. data3 = ds.zip((data1, data2))
  322. num_iter = 0
  323. # each data is a dictionary.
  324. for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
  325. # in this example, each dictionary has keys "image" and "label".
  326. logger.info("image is {}".format(item["image"]))
  327. logger.info("label is {}".format(item["label"]))
  328. num_iter += 1
  329. logger.info("Number of data in data1: {}".format(num_iter))
  330. assert num_iter == 3
  331. def test_lfw_exception():
  332. """
  333. Feature: LFW
  334. Description: test error cases of LFW
  335. Expectation: throw exception correctly
  336. """
  337. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  338. with pytest.raises(RuntimeError, match=error_msg_1):
  339. ds.LFWDataset(DATA_DIR, shuffle=False, decode=True, sampler=ds.SequentialSampler(1))
  340. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  341. with pytest.raises(RuntimeError, match=error_msg_2):
  342. ds.LFWDataset(DATA_DIR, sampler=ds.SequentialSampler(1), decode=True, num_shards=2, shard_id=0)
  343. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  344. with pytest.raises(RuntimeError, match=error_msg_3):
  345. ds.LFWDataset(DATA_DIR, decode=True, num_shards=10)
  346. error_msg_4 = "shard_id is specified but num_shards is not"
  347. with pytest.raises(RuntimeError, match=error_msg_4):
  348. ds.LFWDataset(DATA_DIR, decode=True, shard_id=0)
  349. error_msg_5 = "Input shard_id is not within the required interval"
  350. with pytest.raises(ValueError, match=error_msg_5):
  351. ds.LFWDataset(DATA_DIR, decode=True, num_shards=5, shard_id=-1)
  352. with pytest.raises(ValueError, match=error_msg_5):
  353. ds.LFWDataset(DATA_DIR, decode=True, num_shards=5, shard_id=5)
  354. error_msg_6 = "num_parallel_workers exceeds"
  355. with pytest.raises(ValueError, match=error_msg_6):
  356. ds.LFWDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=0)
  357. with pytest.raises(ValueError, match=error_msg_6):
  358. ds.LFWDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=256)
  359. error_msg_7 = "Argument shard_id"
  360. with pytest.raises(TypeError, match=error_msg_7):
  361. ds.LFWDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0")
  362. error_msg_8 = "does not exist or is not a directory or permission denied!"
  363. with pytest.raises(ValueError, match=error_msg_8):
  364. all_data = ds.LFWDataset("../data/dataset/testLFW2", decode=True)
  365. for _ in all_data.create_dict_iterator(num_epochs=1):
  366. pass
  367. error_msg_9 = "Input task is not within the valid set of ['people', 'pairs']."
  368. with pytest.raises(ValueError, match=re.escape(error_msg_9)):
  369. all_data = ds.LFWDataset(DATA_DIR, task="all")
  370. for _ in all_data.create_dict_iterator(num_epochs=1):
  371. pass
  372. error_msg_10 = "Input usage is not within the valid set of ['10fold', 'train', 'test', 'all']."
  373. with pytest.raises(ValueError, match=re.escape(error_msg_10)):
  374. all_data = ds.LFWDataset(DATA_DIR, usage="many")
  375. for _ in all_data.create_dict_iterator(num_epochs=1):
  376. pass
  377. error_msg_11 = "Input image_set is not within the valid set of ['original', 'funneled', 'deepfunneled']."
  378. with pytest.raises(ValueError, match=re.escape(error_msg_11)):
  379. all_data = ds.LFWDataset(DATA_DIR, image_set="all")
  380. for _ in all_data.create_dict_iterator(num_epochs=1):
  381. pass
  382. error_msg_12 = "Argument decode with value 123 is not of type [<class 'bool'>], but got <class 'int'>."
  383. with pytest.raises(TypeError, match=re.escape(error_msg_12)):
  384. all_data = ds.LFWDataset(DATA_DIR, decode=123)
  385. for _ in all_data.create_dict_iterator(num_epochs=1):
  386. pass
  387. if __name__ == '__main__':
  388. test_lfw_basic()
  389. test_lfw_task()
  390. test_lfw_usage()
  391. test_lfw_image_set()
  392. test_lfw_num_samples()
  393. test_sequential_sampler()
  394. test_random_and_sequentialchained_sampler()
  395. test_lfw_num_shards()
  396. test_lfw_shard_id()
  397. test_lfw_no_shuffle()
  398. test_lfw_decode()
  399. test_lfw_rename()
  400. test_lfw_zip()
  401. test_lfw_exception()