test_datasets_imdb.py — MindSpore unit tests for IMDBDataset (26 kB).
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
import pytest
import mindspore.dataset as ds
from mindspore import log as logger

# Root of the small IMDB test fixture tree. Per the assertions below it holds
# train/test splits with pos/neg classes, two files each (8 samples total).
DATA_DIR = "../data/dataset/testIMDBDataset"
  19. def test_imdb_basic():
  20. """
  21. Feature: Test IMDB Dataset.
  22. Description: read data from all file.
  23. Expectation: the data is processed successfully.
  24. """
  25. logger.info("Test Case basic")
  26. # define parameters
  27. repeat_count = 1
  28. # apply dataset operations
  29. data1 = ds.IMDBDataset(DATA_DIR, shuffle=False)
  30. data1 = data1.repeat(repeat_count)
  31. # Verify dataset size
  32. data1_size = data1.get_dataset_size()
  33. logger.info("dataset size is: {}".format(data1_size))
  34. assert data1_size == 8
  35. content = ["train_pos_0.txt", "train_pos_1.txt", "train_neg_0.txt", "train_neg_1.txt",
  36. "test_pos_0.txt", "test_pos_1.txt", "test_neg_0.txt", "test_neg_1.txt"]
  37. label = [1, 1, 0, 0, 1, 1, 0, 0]
  38. num_iter = 0
  39. for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  40. # each data is a dictionary
  41. # in this example, each dictionary has keys "text" and "label"
  42. strs = item["text"].item().decode("utf8")
  43. logger.info("text is {}".format(strs))
  44. logger.info("label is {}".format(item["label"]))
  45. assert strs == content[index]
  46. assert label[index] == int(item["label"])
  47. num_iter += 1
  48. logger.info("Number of data in data1: {}".format(num_iter))
  49. assert num_iter == 8
  50. def test_imdb_test():
  51. """
  52. Feature: Test IMDB Dataset.
  53. Description: read data from test file.
  54. Expectation: the data is processed successfully.
  55. """
  56. logger.info("Test Case test")
  57. # define parameters
  58. repeat_count = 1
  59. usage = "test"
  60. # apply dataset operations
  61. data1 = ds.IMDBDataset(DATA_DIR, usage=usage, shuffle=False)
  62. data1 = data1.repeat(repeat_count)
  63. # Verify dataset size
  64. data1_size = data1.get_dataset_size()
  65. logger.info("dataset size is: {}".format(data1_size))
  66. assert data1_size == 4
  67. content = ["test_pos_0.txt", "test_pos_1.txt", "test_neg_0.txt", "test_neg_1.txt"]
  68. label = [1, 1, 0, 0]
  69. num_iter = 0
  70. for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  71. # each data is a dictionary
  72. # in this example, each dictionary has keys "text" and "label"
  73. strs = item["text"].item().decode("utf8")
  74. logger.info("text is {}".format(strs))
  75. logger.info("label is {}".format(item["label"]))
  76. assert strs == content[index]
  77. assert label[index] == int(item["label"])
  78. num_iter += 1
  79. logger.info("Number of data in data1: {}".format(num_iter))
  80. assert num_iter == 4
  81. def test_imdb_train():
  82. """
  83. Feature: Test IMDB Dataset.
  84. Description: read data from train file.
  85. Expectation: the data is processed successfully.
  86. """
  87. logger.info("Test Case train")
  88. # define parameters
  89. repeat_count = 1
  90. usage = "train"
  91. # apply dataset operations
  92. data1 = ds.IMDBDataset(DATA_DIR, usage=usage, shuffle=False)
  93. data1 = data1.repeat(repeat_count)
  94. # Verify dataset size
  95. data1_size = data1.get_dataset_size()
  96. logger.info("dataset size is: {}".format(data1_size))
  97. assert data1_size == 4
  98. content = ["train_pos_0.txt", "train_pos_1.txt", "train_neg_0.txt", "train_neg_1.txt"]
  99. label = [1, 1, 0, 0]
  100. num_iter = 0
  101. for index, item in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  102. # each data is a dictionary
  103. # in this example, each dictionary has keys "text" and "label"
  104. strs = item["text"].item().decode("utf8")
  105. logger.info("text is {}".format(strs))
  106. logger.info("label is {}".format(item["label"]))
  107. assert strs == content[index]
  108. assert label[index] == int(item["label"])
  109. num_iter += 1
  110. logger.info("Number of data in data1: {}".format(num_iter))
  111. assert num_iter == 4
def test_imdb_num_samples():
    """
    Feature: Test IMDB Dataset.
    Description: read data from all file with num_samples=6 and num_parallel_workers=2.
    Expectation: the data is processed successfully.
    """
    logger.info("Test Case numSamples")
    # define parameters
    repeat_count = 1
    # apply dataset operations: num_samples caps the read at 6 of the 8 samples
    data1 = ds.IMDBDataset(DATA_DIR, num_samples=6, num_parallel_workers=2)
    data1 = data1.repeat(repeat_count)
    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 6
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # in this example, each dictionary has keys "text" and "label"
        logger.info("text is {}".format(item["text"].item().decode("utf8")))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 6
    # RandomSampler with replacement=True still yields exactly num_samples rows
    random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
    data1 = ds.IMDBDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 3
    # without replacement the count is likewise capped at num_samples
    random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
    data1 = ds.IMDBDataset(DATA_DIR, num_parallel_workers=2, sampler=random_sampler)
    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1
    assert num_iter == 3
  148. def test_imdb_num_shards():
  149. """
  150. Feature: Test IMDB Dataset.
  151. Description: read data from all file with num_shards=2 and shard_id=1.
  152. Expectation: the data is processed successfully.
  153. """
  154. logger.info("Test Case numShards")
  155. # define parameters
  156. repeat_count = 1
  157. # apply dataset operations
  158. data1 = ds.IMDBDataset(DATA_DIR, num_shards=2, shard_id=1)
  159. data1 = data1.repeat(repeat_count)
  160. # Verify dataset size
  161. data1_size = data1.get_dataset_size()
  162. logger.info("dataset size is: {}".format(data1_size))
  163. assert data1_size == 4
  164. num_iter = 0
  165. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  166. # in this example, each dictionary has keys "text" and "label"
  167. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  168. logger.info("label is {}".format(item["label"]))
  169. num_iter += 1
  170. logger.info("Number of data in data1: {}".format(num_iter))
  171. assert num_iter == 4
def test_imdb_shard_id():
    """
    Feature: Test IMDB Dataset.
    Description: read data from all file with num_shards=2 and shard_id=0.
    Expectation: the data is processed successfully.
    """
    logger.info("Test Case withShardID")
    # define parameters
    repeat_count = 1
    # apply dataset operations: read the first of two shards
    data1 = ds.IMDBDataset(DATA_DIR, num_shards=2, shard_id=0)
    data1 = data1.repeat(repeat_count)
    # Verify dataset size: one shard of 2 holds half of the 8 samples
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 4
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # in this example, each dictionary has keys "text" and "label"
        logger.info("text is {}".format(item["text"].item().decode("utf8")))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 4
  196. def test_imdb_no_shuffle():
  197. """
  198. Feature: Test IMDB Dataset.
  199. Description: read data from all file with shuffle=False.
  200. Expectation: the data is processed successfully.
  201. """
  202. logger.info("Test Case noShuffle")
  203. # define parameters
  204. repeat_count = 1
  205. # apply dataset operations
  206. data1 = ds.IMDBDataset(DATA_DIR, shuffle=False)
  207. data1 = data1.repeat(repeat_count)
  208. # Verify dataset size
  209. data1_size = data1.get_dataset_size()
  210. logger.info("dataset size is: {}".format(data1_size))
  211. assert data1_size == 8
  212. num_iter = 0
  213. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  214. # in this example, each dictionary has keys "text" and "label"
  215. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  216. logger.info("label is {}".format(item["label"]))
  217. num_iter += 1
  218. logger.info("Number of data in data1: {}".format(num_iter))
  219. assert num_iter == 8
  220. def test_imdb_true_shuffle():
  221. """
  222. Feature: Test IMDB Dataset.
  223. Description: read data from all file with shuffle=True.
  224. Expectation: the data is processed successfully.
  225. """
  226. logger.info("Test Case extraShuffle")
  227. # define parameters
  228. repeat_count = 2
  229. # apply dataset operations
  230. data1 = ds.IMDBDataset(DATA_DIR, shuffle=True)
  231. data1 = data1.repeat(repeat_count)
  232. # Verify dataset size
  233. data1_size = data1.get_dataset_size()
  234. logger.info("dataset size is: {}".format(data1_size))
  235. assert data1_size == 16
  236. num_iter = 0
  237. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  238. # in this example, each dictionary has keys "text" and "label"
  239. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  240. logger.info("label is {}".format(item["label"]))
  241. num_iter += 1
  242. logger.info("Number of data in data1: {}".format(num_iter))
  243. assert num_iter == 16
  244. def test_random_sampler():
  245. """
  246. Feature: Test IMDB Dataset.
  247. Description: read data from all file with sampler=ds.RandomSampler().
  248. Expectation: the data is processed successfully.
  249. """
  250. logger.info("Test Case RandomSampler")
  251. # define parameters
  252. repeat_count = 1
  253. # apply dataset operations
  254. sampler = ds.RandomSampler()
  255. data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
  256. data1 = data1.repeat(repeat_count)
  257. # Verify dataset size
  258. data1_size = data1.get_dataset_size()
  259. logger.info("dataset size is: {}".format(data1_size))
  260. assert data1_size == 8
  261. num_iter = 0
  262. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  263. # in this example, each dictionary has keys "text" and "label"
  264. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  265. logger.info("label is {}".format(item["label"]))
  266. num_iter += 1
  267. logger.info("Number of data in data1: {}".format(num_iter))
  268. assert num_iter == 8
  269. def test_distributed_sampler():
  270. """
  271. Feature: Test IMDB Dataset.
  272. Description: read data from all file with sampler=ds.DistributedSampler().
  273. Expectation: the data is processed successfully.
  274. """
  275. logger.info("Test Case DistributedSampler")
  276. # define parameters
  277. repeat_count = 1
  278. # apply dataset operations
  279. sampler = ds.DistributedSampler(4, 1)
  280. data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
  281. data1 = data1.repeat(repeat_count)
  282. # Verify dataset size
  283. data1_size = data1.get_dataset_size()
  284. logger.info("dataset size is: {}".format(data1_size))
  285. assert data1_size == 2
  286. num_iter = 0
  287. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  288. # in this example, each dictionary has keys "text" and "label"
  289. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  290. logger.info("label is {}".format(item["label"]))
  291. num_iter += 1
  292. logger.info("Number of data in data1: {}".format(num_iter))
  293. assert num_iter == 2
  294. def test_pk_sampler():
  295. """
  296. Feature: Test IMDB Dataset.
  297. Description: read data from all file with sampler=ds.PKSampler().
  298. Expectation: the data is processed successfully.
  299. """
  300. logger.info("Test Case PKSampler")
  301. # define parameters
  302. repeat_count = 1
  303. # apply dataset operations
  304. sampler = ds.PKSampler(3)
  305. data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
  306. data1 = data1.repeat(repeat_count)
  307. # Verify dataset size
  308. data1_size = data1.get_dataset_size()
  309. logger.info("dataset size is: {}".format(data1_size))
  310. assert data1_size == 6
  311. num_iter = 0
  312. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  313. # in this example, each dictionary has keys "text" and "label"
  314. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  315. logger.info("label is {}".format(item["label"]))
  316. num_iter += 1
  317. logger.info("Number of data in data1: {}".format(num_iter))
  318. assert num_iter == 6
  319. def test_subset_random_sampler():
  320. """
  321. Feature: Test IMDB Dataset.
  322. Description: read data from all file with sampler=ds.SubsetRandomSampler().
  323. Expectation: the data is processed successfully.
  324. """
  325. logger.info("Test Case SubsetRandomSampler")
  326. # define parameters
  327. repeat_count = 1
  328. # apply dataset operations
  329. indices = [0, 3, 1, 2, 5, 4]
  330. sampler = ds.SubsetRandomSampler(indices)
  331. data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
  332. data1 = data1.repeat(repeat_count)
  333. # Verify dataset size
  334. data1_size = data1.get_dataset_size()
  335. logger.info("dataset size is: {}".format(data1_size))
  336. assert data1_size == 6
  337. num_iter = 0
  338. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  339. # in this example, each dictionary has keys "text" and "label"
  340. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  341. logger.info("label is {}".format(item["label"]))
  342. num_iter += 1
  343. logger.info("Number of data in data1: {}".format(num_iter))
  344. assert num_iter == 6
  345. def test_weighted_random_sampler():
  346. """
  347. Feature: Test IMDB Dataset.
  348. Description: read data from all file with sampler=ds.WeightedRandomSampler().
  349. Expectation: the data is processed successfully.
  350. """
  351. logger.info("Test Case WeightedRandomSampler")
  352. # define parameters
  353. repeat_count = 1
  354. # apply dataset operations
  355. weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05]
  356. sampler = ds.WeightedRandomSampler(weights, 6)
  357. data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
  358. data1 = data1.repeat(repeat_count)
  359. # Verify dataset size
  360. data1_size = data1.get_dataset_size()
  361. logger.info("dataset size is: {}".format(data1_size))
  362. assert data1_size == 6
  363. num_iter = 0
  364. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  365. # in this example, each dictionary has keys "text" and "label"
  366. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  367. logger.info("label is {}".format(item["label"]))
  368. num_iter += 1
  369. logger.info("Number of data in data1: {}".format(num_iter))
  370. assert num_iter == 6
def test_weighted_random_sampler_exception():
    """
    Feature: Test IMDB Dataset.
    Description: construct WeightedRandomSampler with invalid weights arguments.
    Expectation: each invalid input raises TypeError or RuntimeError with the expected message.
    """
    logger.info("Test error cases for WeightedRandomSampler")
    # weights must be a list of numbers: a string is rejected at construction
    error_msg_1 = "type of weights element must be number"
    with pytest.raises(TypeError, match=error_msg_1):
        weights = ""
        ds.WeightedRandomSampler(weights)
    # a tuple of numbers is also rejected with the same message
    error_msg_2 = "type of weights element must be number"
    with pytest.raises(TypeError, match=error_msg_2):
        weights = (0.9, 0.8, 1.1)
        ds.WeightedRandomSampler(weights)
    # the value checks below only fire when the sampler is parsed
    error_msg_3 = "WeightedRandomSampler: weights vector must not be empty"
    with pytest.raises(RuntimeError, match=error_msg_3):
        weights = []
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()
    error_msg_4 = "WeightedRandomSampler: weights vector must not contain negative numbers, got: "
    with pytest.raises(RuntimeError, match=error_msg_4):
        weights = [1.0, 0.1, 0.02, 0.3, -0.4]
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()
    error_msg_5 = "WeightedRandomSampler: elements of weights vector must not be all zero"
    with pytest.raises(RuntimeError, match=error_msg_5):
        weights = [0, 0, 0, 0, 0]
        sampler = ds.WeightedRandomSampler(weights)
        sampler.parse()
  401. def test_chained_sampler_with_random_sequential_repeat():
  402. """
  403. Feature: Test IMDB Dataset.
  404. Description: read data from all file with Random and Sequential, with repeat.
  405. Expectation: the data is processed successfully.
  406. """
  407. logger.info("Test Case Chained Sampler - Random and Sequential, with repeat")
  408. # Create chained sampler, random and sequential
  409. sampler = ds.RandomSampler()
  410. child_sampler = ds.SequentialSampler()
  411. sampler.add_child(child_sampler)
  412. # Create IMDBDataset with sampler
  413. data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
  414. data1 = data1.repeat(count=3)
  415. # Verify dataset size
  416. data1_size = data1.get_dataset_size()
  417. logger.info("dataset size is: {}".format(data1_size))
  418. assert data1_size == 24
  419. # Verify number of iterations
  420. num_iter = 0
  421. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  422. # in this example, each dictionary has keys "text" and "label"
  423. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  424. logger.info("label is {}".format(item["label"]))
  425. num_iter += 1
  426. logger.info("Number of data in data1: {}".format(num_iter))
  427. assert num_iter == 24
def test_chained_sampler_with_distribute_random_batch_then_repeat():
    """
    Feature: Test IMDB Dataset.
    Description: read data from all file with Distributed and Random, with batch then repeat.
    Expectation: drop_remainder leaves no complete batch, so zero rows are produced.
    """
    logger.info("Test Case Chained Sampler - Distributed and Random, with batch then repeat")
    # Create chained sampler, distributed and random
    sampler = ds.DistributedSampler(num_shards=4, shard_id=3)
    child_sampler = ds.RandomSampler()
    sampler.add_child(child_sampler)
    # Create IMDBDataset with sampler
    data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
    data1 = data1.batch(batch_size=5, drop_remainder=True)
    data1 = data1.repeat(count=3)
    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 0
    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # in this example, each dictionary has keys "text" and "label"
        logger.info("text is {}".format(item["text"].item().decode("utf8")))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: Each of the 4 shards holds 8/4=2 samples, fewer than batch_size=5,
    # so drop_remainder=True discards them and no batch is ever produced.
    assert num_iter == 0
def test_chained_sampler_with_weighted_random_pk_sampler():
    """
    Feature: Test IMDB Dataset.
    Description: read data from all file with WeightedRandom and PKSampler.
    Expectation: the data is processed successfully.
    """
    logger.info("Test Case Chained Sampler - WeightedRandom and PKSampler")
    # Create chained sampler, WeightedRandom and PKSampler
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05]
    sampler = ds.WeightedRandomSampler(weights=weights, num_samples=6)
    child_sampler = ds.PKSampler(num_val=3)  # 3 elements per class (2 classes: pos/neg)
    sampler.add_child(child_sampler)
    # Create IMDBDataset with sampler
    data1 = ds.IMDBDataset(DATA_DIR, sampler=sampler)
    # Verify dataset size
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is: {}".format(data1_size))
    assert data1_size == 6
    # Verify number of iterations
    num_iter = 0
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):  # each data is a dictionary
        # in this example, each dictionary has keys "text" and "label"
        logger.info("text is {}".format(item["text"].item().decode("utf8")))
        logger.info("label is {}".format(item["label"]))
        num_iter += 1
    logger.info("Number of data in data1: {}".format(num_iter))
    # Note: The parent WeightedRandomSampler was built with num_samples=6,
    # which caps the chained pipeline's output at 6 rows.
    assert num_iter == 6
  487. def test_imdb_rename():
  488. """
  489. Feature: Test IMDB Dataset.
  490. Description: read data from all file with rename.
  491. Expectation: the data is processed successfully.
  492. """
  493. logger.info("Test Case rename")
  494. # define parameters
  495. repeat_count = 1
  496. # apply dataset operations
  497. data1 = ds.IMDBDataset(DATA_DIR, num_samples=8)
  498. data1 = data1.repeat(repeat_count)
  499. num_iter = 0
  500. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  501. # in this example, each dictionary has keys "text" and "label"
  502. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  503. logger.info("label is {}".format(item["label"]))
  504. num_iter += 1
  505. logger.info("Number of data in data1: {}".format(num_iter))
  506. assert num_iter == 8
  507. data1 = data1.rename(input_columns=["text"], output_columns="text2")
  508. num_iter = 0
  509. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  510. # in this example, each dictionary has keys "text" and "label"
  511. logger.info("text is {}".format(item["text2"]))
  512. logger.info("label is {}".format(item["label"]))
  513. num_iter += 1
  514. logger.info("Number of data in data1: {}".format(num_iter))
  515. assert num_iter == 8
  516. def test_imdb_zip():
  517. """
  518. Feature: Test IMDB Dataset.
  519. Description: read data from all file with zip.
  520. Expectation: the data is processed successfully.
  521. """
  522. logger.info("Test Case zip")
  523. # define parameters
  524. repeat_count = 2
  525. # apply dataset operations
  526. data1 = ds.IMDBDataset(DATA_DIR, num_samples=4)
  527. data2 = ds.IMDBDataset(DATA_DIR, num_samples=4)
  528. data1 = data1.repeat(repeat_count)
  529. # rename dataset2 for no conflict
  530. data2 = data2.rename(input_columns=["text", "label"], output_columns=["text1", "label1"])
  531. data3 = ds.zip((data1, data2))
  532. num_iter = 0
  533. for item in data3.create_dict_iterator(num_epochs=1, output_numpy=True): # each data is a dictionary
  534. # in this example, each dictionary has keys "text" and "label"
  535. logger.info("text is {}".format(item["text"].item().decode("utf8")))
  536. logger.info("label is {}".format(item["label"]))
  537. num_iter += 1
  538. logger.info("Number of data in data1: {}".format(num_iter))
  539. assert num_iter == 4
  540. def test_imdb_exception():
  541. """
  542. Feature: Test IMDB Dataset.
  543. Description: read data from all file with exception.
  544. Expectation: the data is processed successfully.
  545. """
  546. logger.info("Test imdb exception")
  547. def exception_func(item):
  548. raise Exception("Error occur!")
  549. def exception_func2(text, label):
  550. raise Exception("Error occur!")
  551. try:
  552. data = ds.IMDBDataset(DATA_DIR)
  553. data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
  554. for _ in data.__iter__():
  555. pass
  556. assert False
  557. except RuntimeError as e:
  558. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  559. try:
  560. data = ds.IMDBDataset(DATA_DIR)
  561. data = data.map(operations=exception_func2, input_columns=["text", "label"],
  562. output_columns=["text", "label", "label1"],
  563. column_order=["text", "label", "label1"], num_parallel_workers=1)
  564. for _ in data.__iter__():
  565. pass
  566. assert False
  567. except RuntimeError as e:
  568. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  569. data_dir_invalid = "../data/dataset/IMDBDATASET"
  570. try:
  571. data = ds.IMDBDataset(data_dir_invalid)
  572. for _ in data.__iter__():
  573. pass
  574. assert False
  575. except ValueError as e:
  576. assert "does not exist or is not a directory or permission denied" in str(e)
if __name__ == '__main__':
    # Allow running the full suite as a plain script (outside pytest collection).
    test_imdb_basic()
    test_imdb_test()
    test_imdb_train()
    test_imdb_num_samples()
    test_random_sampler()
    test_distributed_sampler()
    test_pk_sampler()
    test_subset_random_sampler()
    test_weighted_random_sampler()
    test_weighted_random_sampler_exception()
    test_chained_sampler_with_random_sequential_repeat()
    test_chained_sampler_with_distribute_random_batch_then_repeat()
    test_chained_sampler_with_weighted_random_pk_sampler()
    test_imdb_num_shards()
    test_imdb_shard_id()
    test_imdb_no_shuffle()
    test_imdb_true_shuffle()
    test_imdb_rename()
    test_imdb_zip()
    test_imdb_exception()