You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_wiki_text.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  18. from util import config_get_set_num_parallel_workers, config_get_set_seed
  19. FILE_DIR = '../data/dataset/testWikiText'
  20. def test_wiki_text_dataset_test():
  21. """
  22. Feature: Test WikiText Dataset.
  23. Description: read test data from a single file.
  24. Expectation: the data is processed successfully.
  25. """
  26. data = ds.WikiTextDataset(FILE_DIR, usage='test', shuffle=False)
  27. count = 0
  28. test_content = [" no it was black friday ", " I am happy ", " finish math homework "]
  29. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  30. logger.info("{}".format(i["text"]))
  31. strs = i["text"].item().decode("utf8")
  32. assert strs == test_content[count]
  33. count += 1
  34. assert count == 3
  35. def test_wiki_text_dataset_train():
  36. """
  37. Feature: Test WikiText Dataset.
  38. Description: read train data from a single file.
  39. Expectation: the data is processed successfully.
  40. """
  41. data = ds.WikiTextDataset(FILE_DIR, usage='train', shuffle=False)
  42. count = 0
  43. train_content = [" go to china ", " I lova MindSpore ", " black white grapes "]
  44. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  45. logger.info("{}".format(i["text"]))
  46. strs = i["text"].item().decode("utf8")
  47. assert strs == train_content[count]
  48. count += 1
  49. assert count == 3
  50. def test_wiki_text_dataset_valid():
  51. """
  52. Feature: Test WikiText Dataset.
  53. Description: read valid data from a single file.
  54. Expectation: the data is processed successfully.
  55. """
  56. data = ds.WikiTextDataset(FILE_DIR, usage='valid', shuffle=False)
  57. count = 0
  58. valid_content = [" just ahead of them there was a huge fissure ", " zhejiang, china ", " MindSpore Ascend "]
  59. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  60. logger.info("{}".format(i["text"]))
  61. strs = i["text"].item().decode("utf8")
  62. assert strs == valid_content[count]
  63. count += 1
  64. assert count == 3
  65. def test_wiki_text_dataset_all_file():
  66. """
  67. Feature: Test WikiText Dataset.
  68. Description: read data from all files.
  69. Expectation: the data is processed successfully.
  70. """
  71. data = ds.WikiTextDataset(FILE_DIR, usage='all')
  72. count = 0
  73. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  74. logger.info("{}".format(i["text"]))
  75. count += 1
  76. assert count == 9
  77. def test_wiki_text_dataset_num_samples_none():
  78. """
  79. Feature: Test WikiText Dataset.
  80. Description: read data with no num_samples input.
  81. Expectation: the data is processed successfully.
  82. """
  83. # Do not provide a num_samples argument, so it would be None by default, which means all samples are read.
  84. data = ds.WikiTextDataset(FILE_DIR, usage='all')
  85. count = 0
  86. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  87. logger.info("{}".format(i["text"]))
  88. count += 1
  89. assert count == 9
  90. def test_wiki_text_dataset_shuffle_false_and_workers_4():
  91. """
  92. Feature: Test WikiText Dataset.
  93. Description: read data from a single file with shuffle is False and num_parallel_workers=4.
  94. Expectation: the data is processed successfully.
  95. """
  96. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  97. original_seed = config_get_set_seed(987)
  98. data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=False)
  99. count = 0
  100. line = [" no it was black friday ",
  101. " go to china ",
  102. " just ahead of them there was a huge fissure ",
  103. " I am happy ",
  104. " I lova MindSpore ",
  105. " zhejiang, china ",
  106. " finish math homework ",
  107. " black white grapes ",
  108. " MindSpore Ascend "]
  109. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  110. strs = i["text"].item().decode("utf8")
  111. assert strs == line[count]
  112. count += 1
  113. assert count == 9
  114. # Restore configuration
  115. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  116. ds.config.set_seed(original_seed)
  117. def test_wiki_text_dataset_shuffle_false_and_workers_1():
  118. """
  119. Feature: Test WikiText Dataset.
  120. Description: Read data from a single file with shuffle is False and num_parallel_workers is 1.
  121. Expectation: the data is processed successfully.
  122. """
  123. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  124. original_seed = config_get_set_seed(987)
  125. data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=False)
  126. count = 0
  127. line = [" no it was black friday ",
  128. " I am happy ",
  129. " finish math homework ",
  130. " go to china ",
  131. " I lova MindSpore ",
  132. " black white grapes ",
  133. " just ahead of them there was a huge fissure ",
  134. " zhejiang, china ",
  135. " MindSpore Ascend "]
  136. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  137. strs = i["text"].item().decode("utf8")
  138. assert strs == line[count]
  139. count += 1
  140. assert count == 9
  141. # Restore configuration
  142. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  143. ds.config.set_seed(original_seed)
  144. def test_wiki_text_dataset_shuffle_files_and_workers_4():
  145. """
  146. Feature: Test WikiText Dataset.
  147. Description: read data from a single file with shuffle is files and num_parallel_workers is 4.
  148. Expectation: the data is processed successfully.
  149. """
  150. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  151. original_seed = config_get_set_seed(135)
  152. data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES)
  153. count = 0
  154. line = [" just ahead of them there was a huge fissure ",
  155. " go to china ",
  156. " no it was black friday ",
  157. " zhejiang, china ",
  158. " I lova MindSpore ",
  159. " I am happy ",
  160. " MindSpore Ascend ",
  161. " black white grapes ",
  162. " finish math homework "]
  163. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  164. strs = i["text"].item().decode("utf8")
  165. assert strs == line[count]
  166. count += 1
  167. assert count == 9
  168. # Restore configuration
  169. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  170. ds.config.set_seed(original_seed)
  171. def test_wiki_text_dataset_shuffle_files_and_workers_1():
  172. """
  173. Feature: Test WikiText Dataset.
  174. Description: read data from a single file with shuffle is files and num_parallel_workers is 1.
  175. Expectation: the data is processed successfully.
  176. """
  177. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  178. original_seed = config_get_set_seed(135)
  179. data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES)
  180. count = 0
  181. line = [" just ahead of them there was a huge fissure ",
  182. " zhejiang, china ",
  183. " MindSpore Ascend ",
  184. " go to china ",
  185. " I lova MindSpore ",
  186. " black white grapes ",
  187. " no it was black friday ",
  188. " I am happy ",
  189. " finish math homework "]
  190. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  191. strs = i["text"].item().decode("utf8")
  192. assert strs == line[count]
  193. count += 1
  194. assert count == 9
  195. # Restore configuration
  196. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  197. ds.config.set_seed(original_seed)
  198. def test_wiki_text_dataset_shuffle_global4():
  199. """
  200. Feature: Test WikiText Dataset.
  201. Description: read data from a single file with shuffle is global.
  202. Expectation: the data is processed successfully.
  203. """
  204. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  205. original_seed = config_get_set_seed(246)
  206. data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL)
  207. count = 0
  208. line = [" MindSpore Ascend ",
  209. " go to china ",
  210. " I am happy ",
  211. " no it was black friday ",
  212. " just ahead of them there was a huge fissure ",
  213. " zhejiang, china ",
  214. " finish math homework ",
  215. " I lova MindSpore ",
  216. " black white grapes "]
  217. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  218. strs = i["text"].item().decode("utf8")
  219. assert strs == line[count]
  220. count += 1
  221. assert count == 9
  222. # Restore configuration
  223. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  224. ds.config.set_seed(original_seed)
  225. def test_wiki_text_dataset_shuffle_global1():
  226. """
  227. Feature: Test WikiText Dataset.
  228. Description: read data from a single file with shuffle is global.
  229. Expectation: the data is processed successfully.
  230. """
  231. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  232. original_seed = config_get_set_seed(246)
  233. data = ds.WikiTextDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL)
  234. count = 0
  235. line = [" MindSpore Ascend ",
  236. " go to china ",
  237. " I am happy ",
  238. " I lova MindSpore ",
  239. " black white grapes ",
  240. " finish math homework ",
  241. " zhejiang, china ",
  242. " no it was black friday ",
  243. " just ahead of them there was a huge fissure "]
  244. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  245. strs = i["text"].item().decode("utf8")
  246. assert strs == line[count]
  247. count += 1
  248. assert count == 9
  249. # Restore configuration
  250. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  251. ds.config.set_seed(original_seed)
  252. def test_wiki_text_dataset_num_samples():
  253. """
  254. Feature: Test WikiText Dataset.
  255. Description: Test num_samples.
  256. Expectation: the data is processed successfully.
  257. """
  258. data = ds.WikiTextDataset(FILE_DIR, usage='all', num_samples=2)
  259. count = 0
  260. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  261. count += 1
  262. assert count == 2
  263. def test_wiki_text_dataset_distribution():
  264. """
  265. Feature: Test WikiText Dataset.
  266. Description: read data from a single file.
  267. Expectation: the data is processed successfully.
  268. """
  269. data = ds.WikiTextDataset(FILE_DIR, usage='all', num_shards=2, shard_id=1)
  270. count = 0
  271. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  272. count += 1
  273. assert count == 5
  274. def test_wiki_text_dataset_repeat():
  275. """
  276. Feature: Test WikiText Dataset.
  277. Description: Test repeat.
  278. Expectation: the data is processed successfully.
  279. """
  280. data = ds.WikiTextDataset(FILE_DIR, usage='test', shuffle=False)
  281. data = data.repeat(3)
  282. count = 0
  283. line = [" no it was black friday ",
  284. " I am happy ",
  285. " finish math homework ",
  286. " no it was black friday ",
  287. " I am happy ",
  288. " finish math homework ",
  289. " no it was black friday ",
  290. " I am happy ",
  291. " finish math homework ",]
  292. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  293. strs = i["text"].item().decode("utf8")
  294. assert strs == line[count]
  295. count += 1
  296. assert count == 9
  297. def test_wiki_text_dataset_get_datasetsize():
  298. """
  299. Feature: Test WikiText Dataset.
  300. Description: Test get_datasetsize.
  301. Expectation: the data is processed successfully.
  302. """
  303. data = ds.WikiTextDataset(FILE_DIR, usage='test')
  304. size = data.get_dataset_size()
  305. assert size == 3
  306. def test_wiki_text_dataset_to_device():
  307. """
  308. Feature: Test WikiText Dataset.
  309. Description: Test to_device.
  310. Expectation: the data is processed successfully.
  311. """
  312. data = ds.WikiTextDataset(FILE_DIR, usage='test')
  313. data = data.to_device()
  314. data.send()
  315. def test_wiki_text_dataset_exceptions():
  316. """
  317. Feature: Test WikiText Dataset.
  318. Description: Test exceptions.
  319. Expectation: Exception thrown to be caught
  320. """
  321. with pytest.raises(ValueError) as error_info:
  322. _ = ds.WikiTextDataset(FILE_DIR, usage='test', num_samples=-1)
  323. assert "num_samples exceeds the boundary" in str(error_info.value)
  324. with pytest.raises(ValueError) as error_info:
  325. _ = ds.WikiTextDataset("does/not/exist/no.txt")
  326. assert str(error_info.value)
  327. with pytest.raises(ValueError) as error_info:
  328. _ = ds.WikiTextDataset("")
  329. assert str(error_info.value)
  330. def exception_func(item):
  331. raise Exception("Error occur!")
  332. with pytest.raises(RuntimeError) as error_info:
  333. data = ds.WikiTextDataset(FILE_DIR)
  334. data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
  335. for _ in data.__iter__():
  336. pass
  337. assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
  338. if __name__ == "__main__":
  339. test_wiki_text_dataset_test()
  340. test_wiki_text_dataset_train()
  341. test_wiki_text_dataset_valid()
  342. test_wiki_text_dataset_all_file()
  343. test_wiki_text_dataset_num_samples_none()
  344. test_wiki_text_dataset_shuffle_false_and_workers_4()
  345. test_wiki_text_dataset_shuffle_false_and_workers_1()
  346. test_wiki_text_dataset_shuffle_files_and_workers_4()
  347. test_wiki_text_dataset_shuffle_files_and_workers_1()
  348. test_wiki_text_dataset_shuffle_global4()
  349. test_wiki_text_dataset_shuffle_global1()
  350. test_wiki_text_dataset_num_samples()
  351. test_wiki_text_dataset_distribution()
  352. test_wiki_text_dataset_repeat()
  353. test_wiki_text_dataset_get_datasetsize()
  354. test_wiki_text_dataset_to_device()
  355. test_wiki_text_dataset_exceptions()