You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_lsun.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Test LSUN dataset operators
  17. """
  18. import pytest
  19. import mindspore.dataset as ds
  20. import mindspore.dataset.vision.c_transforms as vision
  21. from mindspore import log as logger
  22. DATA_DIR = "../data/dataset/testLSUN"
  23. def test_lsun_basic():
  24. """
  25. Feature: LSUN
  26. Description: test basic usage of LSUN
  27. Expectation: the dataset is as expected
  28. """
  29. logger.info("Test Case basic")
  30. # define parameters
  31. repeat_count = 1
  32. # apply dataset operations
  33. data1 = ds.LSUNDataset(DATA_DIR)
  34. data1 = data1.repeat(repeat_count)
  35. num_iter = 0
  36. # each data is a dictionary
  37. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  38. # in this example, each dictionary has keys "image" and "label"
  39. logger.info("image is {}".format(item["image"]))
  40. logger.info("label is {}".format(item["label"]))
  41. num_iter += 1
  42. logger.info("Number of data in data1: {}".format(num_iter))
  43. assert num_iter == 4
  44. def test_lsun_num_samples():
  45. """
  46. Feature: LSUN
  47. Description: test basic usage of LSUN
  48. Expectation: the dataset is as expected
  49. """
  50. logger.info("Test Case num_samples")
  51. # define parameters
  52. repeat_count = 1
  53. # apply dataset operations
  54. data1 = ds.LSUNDataset(DATA_DIR, num_samples=10, num_parallel_workers=2)
  55. data1 = data1.repeat(repeat_count)
  56. num_iter = 0
  57. # each data is a dictionary
  58. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  59. # in this example, each dictionary has keys "image" and "label"
  60. logger.info("image is {}".format(item["image"]))
  61. logger.info("label is {}".format(item["label"]))
  62. num_iter += 1
  63. logger.info("Number of data in data1: {}".format(num_iter))
  64. assert num_iter == 4
  65. random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
  66. data1 = ds.LSUNDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
  67. num_iter = 0
  68. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  69. num_iter += 1
  70. assert num_iter == 3
  71. random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
  72. data1 = ds.LSUNDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
  73. num_iter = 0
  74. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  75. num_iter += 1
  76. assert num_iter == 3
  77. def test_lsun_num_shards():
  78. """
  79. Feature: LSUN
  80. Description: test basic usage of LSUN
  81. Expectation: the dataset is as expected
  82. """
  83. logger.info("Test Case numShards")
  84. # define parameters
  85. repeat_count = 1
  86. # apply dataset operations
  87. data1 = ds.LSUNDataset(DATA_DIR, num_shards=2, shard_id=1)
  88. data1 = data1.repeat(repeat_count)
  89. num_iter = 0
  90. # each data is a dictionary
  91. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  92. # in this example, each dictionary has keys "image" and "label"
  93. logger.info("image is {}".format(item["image"]))
  94. logger.info("label is {}".format(item["label"]))
  95. num_iter += 1
  96. logger.info("Number of data in data1: {}".format(num_iter))
  97. assert num_iter == 2
  98. def test_lsun_shard_id():
  99. """
  100. Feature: LSUN
  101. Description: test basic usage of LSUN
  102. Expectation: the dataset is as expected
  103. """
  104. logger.info("Test Case withShardID")
  105. # define parameters
  106. repeat_count = 1
  107. # apply dataset operations
  108. data1 = ds.LSUNDataset(DATA_DIR, num_shards=2, shard_id=0)
  109. data1 = data1.repeat(repeat_count)
  110. num_iter = 0
  111. # each data is a dictionary
  112. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  113. # in this example, each dictionary has keys "image" and "label"
  114. logger.info("image is {}".format(item["image"]))
  115. logger.info("label is {}".format(item["label"]))
  116. num_iter += 1
  117. logger.info("Number of data in data1: {}".format(num_iter))
  118. assert num_iter == 2
  119. def test_lsun_no_shuffle():
  120. """
  121. Feature: LSUN
  122. Description: test basic usage of LSUN
  123. Expectation: the dataset is as expected
  124. """
  125. logger.info("Test Case noShuffle")
  126. # define parameters
  127. repeat_count = 1
  128. # apply dataset operations
  129. data1 = ds.LSUNDataset(DATA_DIR, shuffle=False)
  130. data1 = data1.repeat(repeat_count)
  131. num_iter = 0
  132. # each data is a dictionary
  133. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  134. # in this example, each dictionary has keys "image" and "label"
  135. logger.info("image is {}".format(item["image"]))
  136. logger.info("label is {}".format(item["label"]))
  137. num_iter += 1
  138. logger.info("Number of data in data1: {}".format(num_iter))
  139. assert num_iter == 4
  140. def test_lsun_extra_shuffle():
  141. """
  142. Feature: LSUN
  143. Description: test basic usage of LSUN
  144. Expectation: the dataset is as expected
  145. """
  146. logger.info("Test Case extra_shuffle")
  147. # define parameters
  148. repeat_count = 2
  149. # apply dataset operations
  150. data1 = ds.LSUNDataset(DATA_DIR, shuffle=True)
  151. data1 = data1.shuffle(buffer_size=5)
  152. data1 = data1.repeat(repeat_count)
  153. num_iter = 0
  154. # each data is a dictionary
  155. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  156. # in this example, each dictionary has keys "image" and "label"
  157. logger.info("image is {}".format(item["image"]))
  158. logger.info("label is {}".format(item["label"]))
  159. num_iter += 1
  160. logger.info("Number of data in data1: {}".format(num_iter))
  161. assert num_iter == 8
  162. def test_lsun_decode():
  163. """
  164. Feature: LSUN
  165. Description: test basic usage of LSUN
  166. Expectation: the dataset is as expected
  167. """
  168. logger.info("Test Case decode")
  169. # define parameters
  170. repeat_count = 1
  171. # apply dataset operations
  172. data1 = ds.LSUNDataset(DATA_DIR, decode=True)
  173. data1 = data1.repeat(repeat_count)
  174. num_iter = 0
  175. # each data is a dictionary
  176. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  177. # in this example, each dictionary has keys "image" and "label"
  178. logger.info("image is {}".format(item["image"]))
  179. logger.info("label is {}".format(item["label"]))
  180. num_iter += 1
  181. logger.info("Number of data in data1: {}".format(num_iter))
  182. assert num_iter == 4
  183. def test_sequential_sampler():
  184. """
  185. Feature: LSUN
  186. Description: test basic usage of LSUN
  187. Expectation: the dataset is as expected
  188. """
  189. logger.info("Test Case SequentialSampler")
  190. # define parameters
  191. repeat_count = 1
  192. # apply dataset operations
  193. sampler = ds.SequentialSampler(num_samples=10)
  194. data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
  195. data1 = data1.repeat(repeat_count)
  196. result = []
  197. num_iter = 0
  198. # each data is a dictionary
  199. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  200. # in this example, each dictionary has keys "image" and "label"
  201. result.append(item["label"])
  202. num_iter += 1
  203. assert num_iter == 2
  204. logger.info("Result: {}".format(result))
  205. def test_random_sampler():
  206. """
  207. Feature: LSUN
  208. Description: test basic usage of LSUN
  209. Expectation: the dataset is as expected
  210. """
  211. logger.info("Test Case RandomSampler")
  212. # define parameters
  213. repeat_count = 1
  214. # apply dataset operations
  215. sampler = ds.RandomSampler()
  216. data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
  217. data1 = data1.repeat(repeat_count)
  218. num_iter = 0
  219. # each data is a dictionary
  220. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  221. # in this example, each dictionary has keys "image" and "label"
  222. logger.info("image is {}".format(item["image"]))
  223. logger.info("label is {}".format(item["label"]))
  224. num_iter += 1
  225. logger.info("Number of data in data1: {}".format(num_iter))
  226. assert num_iter == 2
  227. def test_distributed_sampler():
  228. """
  229. Feature: LSUN
  230. Description: test basic usage of LSUN
  231. Expectation: the dataset is as expected
  232. """
  233. logger.info("Test Case DistributedSampler")
  234. # define parameters
  235. repeat_count = 1
  236. # apply dataset operations
  237. sampler = ds.DistributedSampler(2, 1)
  238. data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
  239. data1 = data1.repeat(repeat_count)
  240. num_iter = 0
  241. # each data is a dictionary
  242. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  243. # in this example, each dictionary has keys "image" and "label"
  244. logger.info("image is {}".format(item["image"]))
  245. logger.info("label is {}".format(item["label"]))
  246. num_iter += 1
  247. logger.info("Number of data in data1: {}".format(num_iter))
  248. assert num_iter == 1
  249. def test_pk_sampler():
  250. """
  251. Feature: LSUN
  252. Description: test basic usage of LSUN
  253. Expectation: the dataset is as expected
  254. """
  255. logger.info("Test Case PKSampler")
  256. # define parameters
  257. repeat_count = 1
  258. # apply dataset operations
  259. sampler = ds.PKSampler(1)
  260. data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
  261. data1 = data1.repeat(repeat_count)
  262. num_iter = 0
  263. # each data is a dictionary
  264. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  265. # in this example, each dictionary has keys "image" and "label"
  266. logger.info("image is {}".format(item["image"]))
  267. logger.info("label is {}".format(item["label"]))
  268. num_iter += 1
  269. logger.info("Number of data in data1: {}".format(num_iter))
  270. assert num_iter == 2
  271. def test_chained_sampler():
  272. """
  273. Feature: LSUN
  274. Description: test basic usage of LSUN
  275. Expectation: the dataset is as expected
  276. """
  277. logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
  278. # Create chained sampler, random and sequential
  279. sampler = ds.RandomSampler()
  280. child_sampler = ds.SequentialSampler()
  281. sampler.add_child(child_sampler)
  282. # Create LSUNDataset with sampler
  283. data1 = ds.LSUNDataset(DATA_DIR, usage="train", sampler=sampler)
  284. data1 = data1.repeat(count=3)
  285. # Verify dataset size
  286. data1_size = data1.get_dataset_size()
  287. logger.info("dataset size is: {}".format(data1_size))
  288. assert data1_size == 6
  289. # Verify number of iterations
  290. num_iter = 0
  291. # each data is a dictionary
  292. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  293. # in this example, each dictionary has keys "image" and "label"
  294. logger.info("image is {}".format(item["image"]))
  295. logger.info("label is {}".format(item["label"]))
  296. num_iter += 1
  297. logger.info("Number of data in data1: {}".format(num_iter))
  298. assert num_iter == 6
  299. def test_lsun_test_dataset():
  300. """
  301. Feature: LSUN
  302. Description: test basic usage of LSUN
  303. Expectation: the dataset is as expected
  304. """
  305. logger.info("Test Case usage")
  306. # apply dataset operations
  307. data1 = ds.LSUNDataset(DATA_DIR, usage="test", num_samples=8)
  308. num_iter = 0
  309. # each data is a dictionary
  310. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  311. # in this example, each dictionary has keys "image" and "label"
  312. logger.info("image is {}".format(item["image"]))
  313. logger.info("label is {}".format(item["label"]))
  314. num_iter += 1
  315. logger.info("Number of data in data1: {}".format(num_iter))
  316. assert num_iter == 1
  317. def test_lsun_valid_dataset():
  318. """
  319. Feature: LSUN
  320. Description: test basic usage of LSUN
  321. Expectation: the dataset is as expected
  322. """
  323. logger.info("Test Case usage")
  324. # apply dataset operations
  325. data1 = ds.LSUNDataset(DATA_DIR, usage="valid", num_samples=8)
  326. num_iter = 0
  327. # each data is a dictionary
  328. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  329. # in this example, each dictionary has keys "image" and "label"
  330. logger.info("image is {}".format(item["image"]))
  331. logger.info("label is {}".format(item["label"]))
  332. num_iter += 1
  333. logger.info("Number of data in data1: {}".format(num_iter))
  334. assert num_iter == 2
  335. def test_lsun_train_dataset():
  336. """
  337. Feature: LSUN
  338. Description: test basic usage of LSUN
  339. Expectation: the dataset is as expected
  340. """
  341. logger.info("Test Case usage")
  342. # apply dataset operations
  343. data1 = ds.LSUNDataset(DATA_DIR, usage="train", num_samples=8)
  344. num_iter = 0
  345. # each data is a dictionary
  346. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  347. # in this example, each dictionary has keys "image" and "label"
  348. logger.info("image is {}".format(item["image"]))
  349. logger.info("label is {}".format(item["label"]))
  350. num_iter += 1
  351. logger.info("Number of data in data1: {}".format(num_iter))
  352. assert num_iter == 2
  353. def test_lsun_all_dataset():
  354. """
  355. Feature: LSUN
  356. Description: test basic usage of LSUN
  357. Expectation: the dataset is as expected
  358. """
  359. logger.info("Test Case usage")
  360. # apply dataset operations
  361. data1 = ds.LSUNDataset(DATA_DIR, usage="all", num_samples=8)
  362. num_iter = 0
  363. # each data is a dictionary
  364. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  365. # in this example, each dictionary has keys "image" and "label"
  366. logger.info("image is {}".format(item["image"]))
  367. logger.info("label is {}".format(item["label"]))
  368. num_iter += 1
  369. logger.info("Number of data in data1: {}".format(num_iter))
  370. assert num_iter == 4
  371. def test_lsun_classes():
  372. """
  373. Feature: LSUN
  374. Description: test classes of LSUN
  375. Expectation: the dataset is as expected
  376. """
  377. logger.info("Test Case usage")
  378. # apply dataset operations
  379. data1 = ds.LSUNDataset(DATA_DIR, usage="train", classes=["bedroom"], num_samples=8)
  380. num_iter = 0
  381. # each data is a dictionary
  382. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  383. # in this example, each dictionary has keys "image" and "label"
  384. logger.info("image is {}".format(item["image"]))
  385. logger.info("label is {}".format(item["label"]))
  386. num_iter += 1
  387. logger.info("Number of data in data1: {}".format(num_iter))
  388. assert num_iter == 1
  389. def test_lsun_zip():
  390. """
  391. Feature: LSUN
  392. Description: test basic usage of LSUN
  393. Expectation: the dataset is as expected
  394. """
  395. logger.info("Test Case zip")
  396. # define parameters
  397. repeat_count = 2
  398. # apply dataset operations
  399. data1 = ds.LSUNDataset(DATA_DIR, num_samples=10)
  400. data2 = ds.LSUNDataset(DATA_DIR, num_samples=10)
  401. data1 = data1.repeat(repeat_count)
  402. # rename dataset2 for no conflict
  403. data2 = data2.rename(input_columns=["image", "label"], output_columns=["image1", "label1"])
  404. data3 = ds.zip((data1, data2))
  405. num_iter = 0
  406. # each data is a dictionary
  407. for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True):
  408. # in this example, each dictionary has keys "image" and "label"
  409. logger.info("image is {}".format(item["image"]))
  410. logger.info("label is {}".format(item["label"]))
  411. num_iter += 1
  412. logger.info("Number of data in data1: {}".format(num_iter))
  413. assert num_iter == 4
  414. def test_lsun_exception():
  415. """
  416. Feature: LSUN
  417. Description: test error cases for LSUN
  418. Expectation: throw exception correctly
  419. """
  420. logger.info("Test lsun exception")
  421. error_msg_1 = "sampler and shuffle cannot be specified at the same time"
  422. with pytest.raises(RuntimeError, match=error_msg_1):
  423. ds.LSUNDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
  424. error_msg_2 = "sampler and sharding cannot be specified at the same time"
  425. with pytest.raises(RuntimeError, match=error_msg_2):
  426. ds.LSUNDataset(DATA_DIR, sampler=ds.PKSampler(3), num_shards=2, shard_id=0)
  427. error_msg_3 = "num_shards is specified and currently requires shard_id as well"
  428. with pytest.raises(RuntimeError, match=error_msg_3):
  429. ds.LSUNDataset(DATA_DIR, num_shards=10)
  430. error_msg_4 = "shard_id is specified but num_shards is not"
  431. with pytest.raises(RuntimeError, match=error_msg_4):
  432. ds.LSUNDataset(DATA_DIR, shard_id=0)
  433. error_msg_5 = "Input shard_id is not within the required interval"
  434. with pytest.raises(ValueError, match=error_msg_5):
  435. ds.LSUNDataset(DATA_DIR, num_shards=5, shard_id=-1)
  436. with pytest.raises(ValueError, match=error_msg_5):
  437. ds.LSUNDataset(DATA_DIR, num_shards=5, shard_id=5)
  438. with pytest.raises(ValueError, match=error_msg_5):
  439. ds.LSUNDataset(DATA_DIR, num_shards=2, shard_id=5)
  440. error_msg_6 = "num_parallel_workers exceeds"
  441. with pytest.raises(ValueError, match=error_msg_6):
  442. ds.LSUNDataset(DATA_DIR, shuffle=False, num_parallel_workers=0)
  443. with pytest.raises(ValueError, match=error_msg_6):
  444. ds.LSUNDataset(DATA_DIR, shuffle=False, num_parallel_workers=256)
  445. with pytest.raises(ValueError, match=error_msg_6):
  446. ds.LSUNDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2)
  447. error_msg_7 = "Argument shard_id"
  448. with pytest.raises(TypeError, match=error_msg_7):
  449. ds.LSUNDataset(DATA_DIR, num_shards=2, shard_id="0")
  450. def test_lsun_exception_map():
  451. """
  452. Feature: LSUN
  453. Description: test error cases for LSUN
  454. Expectation: throw exception correctly
  455. """
  456. logger.info("Test lsun exception map")
  457. def exception_func(item):
  458. raise Exception("Error occur!")
  459. def exception_func2(image, label):
  460. raise Exception("Error occur!")
  461. try:
  462. data = ds.LSUNDataset(DATA_DIR)
  463. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  464. for _ in data.__iter__():
  465. pass
  466. assert False
  467. except RuntimeError as e:
  468. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  469. try:
  470. data = ds.LSUNDataset(DATA_DIR)
  471. data = data.map(operations=exception_func2,
  472. input_columns=["image", "label"],
  473. output_columns=["image", "label", "label1"],
  474. column_order=["image", "label", "label1"],
  475. num_parallel_workers=1)
  476. for _ in data.__iter__():
  477. pass
  478. assert False
  479. except RuntimeError as e:
  480. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  481. try:
  482. data = ds.LSUNDataset(DATA_DIR)
  483. data = data.map(operations=vision.Decode(), input_columns=["image"], num_parallel_workers=1)
  484. data = data.map(operations=exception_func, input_columns=["image"], num_parallel_workers=1)
  485. for _ in data.__iter__():
  486. pass
  487. assert False
  488. except RuntimeError as e:
  489. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  490. if __name__ == '__main__':
  491. test_lsun_basic()
  492. test_lsun_num_samples()
  493. test_sequential_sampler()
  494. test_random_sampler()
  495. test_distributed_sampler()
  496. test_pk_sampler()
  497. test_lsun_num_shards()
  498. test_lsun_shard_id()
  499. test_lsun_no_shuffle()
  500. test_lsun_extra_shuffle()
  501. test_lsun_decode()
  502. test_lsun_test_dataset()
  503. test_lsun_valid_dataset()
  504. test_lsun_train_dataset()
  505. test_lsun_all_dataset()
  506. test_lsun_classes()
  507. test_lsun_zip()
  508. test_lsun_exception()
  509. test_lsun_exception_map()