You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_penn_treebank.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  18. from util import config_get_set_num_parallel_workers, config_get_set_seed
# Directory that holds the PennTreebank test fixtures read by all tests below.
FILE_DIR = '../data/dataset/testPennTreebank'
  20. def test_penn_treebank_dataset_one_file():
  21. """
  22. Feature: Test PennTreebank Dataset.
  23. Description: read data from a single file.
  24. Expectation: the data is processed successfully.
  25. """
  26. data = ds.PennTreebankDataset(FILE_DIR, usage='test')
  27. count = 0
  28. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  29. logger.info("{}".format(i["text"]))
  30. count += 1
  31. assert count == 3
  32. def test_penn_treebank_dataset_train():
  33. """
  34. Feature: Test PennTreebank Dataset.
  35. Description: read data from a single file.
  36. Expectation: the data is processed successfully.
  37. """
  38. data = ds.PennTreebankDataset(FILE_DIR, usage='train')
  39. count = 0
  40. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  41. logger.info("{}".format(i["text"]))
  42. count += 1
  43. assert count == 3
  44. def test_penn_treebank_dataset_valid():
  45. """
  46. Feature: Test PennTreebank Dataset.
  47. Description: read data from a single file.
  48. Expectation: the data is processed successfully.
  49. """
  50. data = ds.PennTreebankDataset(FILE_DIR, usage='valid')
  51. count = 0
  52. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  53. logger.info("{}".format(i["text"]))
  54. count += 1
  55. assert count == 3
  56. def test_penn_treebank_dataset_all_file():
  57. """
  58. Feature: Test PennTreebank Dataset.
  59. Description: read data from a single file.
  60. Expectation: the data is processed successfully.
  61. """
  62. data = ds.PennTreebankDataset(FILE_DIR, usage='all')
  63. count = 0
  64. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  65. logger.info("{}".format(i["text"]))
  66. count += 1
  67. assert count == 9
  68. def test_penn_treebank_dataset_num_samples_none():
  69. """
  70. Feature: Test PennTreebank Dataset.
  71. Description: read data with no num_samples input.
  72. Expectation: the data is processed successfully.
  73. """
  74. # Do not provide a num_samples argument, so it would be None by default
  75. data = ds.PennTreebankDataset(FILE_DIR, usage='all')
  76. count = 0
  77. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  78. logger.info("{}".format(i["text"]))
  79. count += 1
  80. assert count == 9
  81. def test_penn_treebank_dataset_shuffle_false4():
  82. """
  83. Feature: Test PennTreebank Dataset.
  84. Description: read data from a single file with shulle is false.
  85. Expectation: the data is processed successfully.
  86. """
  87. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  88. original_seed = config_get_set_seed(987)
  89. data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=False)
  90. count = 0
  91. line = [" no it was black friday ",
  92. " does the bank charge a fee for setting up the account ",
  93. " just ahead of them there was a huge fissure ",
  94. " clash twits poetry formulate flip loyalty splash ",
  95. " <unk> the wardrobe was very small in our room ",
  96. " <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
  97. " you pay less for the supermaket's own brands ",
  98. " black white grapes ",
  99. " everyone in our football team is fuming "]
  100. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  101. strs = i["text"].item().decode("utf8")
  102. assert strs == line[count]
  103. count += 1
  104. assert count == 9
  105. # Restore configuration
  106. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  107. ds.config.set_seed(original_seed)
  108. def test_penn_treebank_dataset_shuffle_false1():
  109. """
  110. Feature: Test PennTreebank Dataset.
  111. Description: read data from a single file with shulle is false.
  112. Expectation: the data is processed successfully.
  113. """
  114. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  115. original_seed = config_get_set_seed(987)
  116. data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=False)
  117. count = 0
  118. line = [" no it was black friday ",
  119. " clash twits poetry formulate flip loyalty splash ",
  120. " you pay less for the supermaket's own brands ",
  121. " does the bank charge a fee for setting up the account ",
  122. " <unk> the wardrobe was very small in our room ",
  123. " black white grapes ",
  124. " just ahead of them there was a huge fissure ",
  125. " <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
  126. " everyone in our football team is fuming "]
  127. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  128. strs = i["text"].item().decode("utf8")
  129. assert strs == line[count]
  130. count += 1
  131. assert count == 9
  132. # Restore configuration
  133. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  134. ds.config.set_seed(original_seed)
  135. def test_penn_treebank_dataset_shuffle_files4():
  136. """
  137. Feature: Test PennTreebank Dataset.
  138. Description: read data from a single file with shulle is files.
  139. Expectation: the data is processed successfully.
  140. """
  141. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  142. original_seed = config_get_set_seed(135)
  143. data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES)
  144. count = 0
  145. line = [" just ahead of them there was a huge fissure ",
  146. " does the bank charge a fee for setting up the account ",
  147. " no it was black friday ",
  148. " <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
  149. " <unk> the wardrobe was very small in our room ",
  150. " clash twits poetry formulate flip loyalty splash ",
  151. " everyone in our football team is fuming ",
  152. " black white grapes ",
  153. " you pay less for the supermaket's own brands "]
  154. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  155. strs = i["text"].item().decode("utf8")
  156. assert strs == line[count]
  157. count += 1
  158. assert count == 9
  159. # Restore configuration
  160. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  161. ds.config.set_seed(original_seed)
  162. def test_penn_treebank_dataset_shuffle_files1():
  163. """
  164. Feature: Test PennTreebank Dataset.
  165. Description: read data from a single file with shulle is files.
  166. Expectation: the data is processed successfully.
  167. """
  168. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  169. original_seed = config_get_set_seed(135)
  170. data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.FILES)
  171. count = 0
  172. line = [" just ahead of them there was a huge fissure ",
  173. " <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
  174. " everyone in our football team is fuming ",
  175. " does the bank charge a fee for setting up the account ",
  176. " <unk> the wardrobe was very small in our room ",
  177. " black white grapes ",
  178. " no it was black friday ",
  179. " clash twits poetry formulate flip loyalty splash ",
  180. " you pay less for the supermaket's own brands "]
  181. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  182. strs = i["text"].item().decode("utf8")
  183. assert strs == line[count]
  184. count += 1
  185. assert count == 9
  186. # Restore configuration
  187. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  188. ds.config.set_seed(original_seed)
  189. def test_penn_treebank_dataset_shuffle_global4():
  190. """
  191. Feature: Test PennTreebank Dataset.
  192. Description: read data from a single file with shulle is global.
  193. Expectation: the data is processed successfully.
  194. """
  195. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  196. original_seed = config_get_set_seed(246)
  197. data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL)
  198. count = 0
  199. line = [" everyone in our football team is fuming ",
  200. " does the bank charge a fee for setting up the account ",
  201. " clash twits poetry formulate flip loyalty splash ",
  202. " no it was black friday ",
  203. " just ahead of them there was a huge fissure ",
  204. " <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
  205. " you pay less for the supermaket's own brands ",
  206. " <unk> the wardrobe was very small in our room ",
  207. " black white grapes "]
  208. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  209. strs = i["text"].item().decode("utf8")
  210. assert strs == line[count]
  211. count += 1
  212. assert count == 9
  213. # Restore configuration
  214. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  215. ds.config.set_seed(original_seed)
  216. def test_penn_treebank_dataset_shuffle_global1():
  217. """
  218. Feature: Test PennTreebank Dataset.
  219. Description: read data from a single file with shulle is global.
  220. Expectation: the data is processed successfully.
  221. """
  222. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  223. original_seed = config_get_set_seed(246)
  224. data = ds.PennTreebankDataset(FILE_DIR, usage='all', shuffle=ds.Shuffle.GLOBAL)
  225. count = 0
  226. line = [" everyone in our football team is fuming ",
  227. " does the bank charge a fee for setting up the account ",
  228. " clash twits poetry formulate flip loyalty splash ",
  229. " <unk> the wardrobe was very small in our room ",
  230. " black white grapes ",
  231. " you pay less for the supermaket's own brands ",
  232. " <unk> <unk> the proportion of female workers in this company <unk> <unk> ",
  233. " no it was black friday ",
  234. " just ahead of them there was a huge fissure "]
  235. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  236. strs = i["text"].item().decode("utf8")
  237. assert strs == line[count]
  238. count += 1
  239. assert count == 9
  240. # Restore configuration
  241. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  242. ds.config.set_seed(original_seed)
  243. def test_penn_treebank_dataset_num_samples():
  244. """
  245. Feature: Test PennTreebank Dataset.
  246. Description: Test num_samples.
  247. Expectation: the data is processed successfully.
  248. """
  249. data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_samples=2)
  250. count = 0
  251. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  252. count += 1
  253. assert count == 2
  254. def test_penn_treebank_dataset_distribution():
  255. """
  256. Feature: Test PennTreebank Dataset.
  257. Description: read data from a single file.
  258. Expectation: the data is processed successfully.
  259. """
  260. data = ds.PennTreebankDataset(FILE_DIR, usage='all', num_shards=2, shard_id=1)
  261. count = 0
  262. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  263. count += 1
  264. assert count == 5
  265. def test_penn_treebank_dataset_repeat():
  266. """
  267. Feature: Test PennTreebank Dataset.
  268. Description: Test repeat.
  269. Expectation: the data is processed successfully.
  270. """
  271. data = ds.PennTreebankDataset(FILE_DIR, usage='test', shuffle=False)
  272. data = data.repeat(3)
  273. count = 0
  274. line = [" no it was black friday ",
  275. " clash twits poetry formulate flip loyalty splash ",
  276. " you pay less for the supermaket's own brands ",
  277. " no it was black friday ",
  278. " clash twits poetry formulate flip loyalty splash ",
  279. " you pay less for the supermaket's own brands ",
  280. " no it was black friday ",
  281. " clash twits poetry formulate flip loyalty splash ",
  282. " you pay less for the supermaket's own brands ",]
  283. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  284. strs = i["text"].item().decode("utf8")
  285. assert strs == line[count]
  286. count += 1
  287. assert count == 9
  288. def test_penn_treebank_dataset_get_datasetsize():
  289. """
  290. Feature: Test PennTreebank Dataset.
  291. Description: Test get_datasetsize.
  292. Expectation: the data is processed successfully.
  293. """
  294. data = ds.PennTreebankDataset(FILE_DIR, usage='test')
  295. size = data.get_dataset_size()
  296. assert size == 3
  297. def test_penn_treebank_dataset_to_device():
  298. """
  299. Feature: Test PennTreebank Dataset.
  300. Description: Test to_device.
  301. Expectation: the data is processed successfully.
  302. """
  303. data = ds.PennTreebankDataset(FILE_DIR, usage='test')
  304. data = data.to_device()
  305. data.send()
  306. def test_penn_treebank_dataset_exceptions():
  307. """
  308. Feature: Test PennTreebank Dataset.
  309. Description: Test exceptions.
  310. Expectation: Exception thrown to be caught
  311. """
  312. with pytest.raises(ValueError) as error_info:
  313. _ = ds.PennTreebankDataset(FILE_DIR, usage='test', num_samples=-1)
  314. assert "num_samples exceeds the boundary" in str(error_info.value)
  315. with pytest.raises(ValueError) as error_info:
  316. _ = ds.PennTreebankDataset("does/not/exist/no.txt")
  317. assert str(error_info.value)
  318. with pytest.raises(ValueError) as error_info:
  319. _ = ds.PennTreebankDataset("")
  320. assert str(error_info.value)
  321. def exception_func(item):
  322. raise Exception("Error occur!")
  323. with pytest.raises(RuntimeError) as error_info:
  324. data = ds.PennTreebankDataset(FILE_DIR)
  325. data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
  326. for _ in data.__iter__():
  327. pass
  328. assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
  329. if __name__ == "__main__":
  330. test_penn_treebank_dataset_one_file()
  331. test_penn_treebank_dataset_train()
  332. test_penn_treebank_dataset_valid()
  333. test_penn_treebank_dataset_all_file()
  334. test_penn_treebank_dataset_num_samples_none()
  335. test_penn_treebank_dataset_shuffle_false4()
  336. test_penn_treebank_dataset_shuffle_false1()
  337. test_penn_treebank_dataset_shuffle_files4()
  338. test_penn_treebank_dataset_shuffle_files1()
  339. test_penn_treebank_dataset_shuffle_global4()
  340. test_penn_treebank_dataset_shuffle_global1()
  341. test_penn_treebank_dataset_num_samples()
  342. test_penn_treebank_dataset_distribution()
  343. test_penn_treebank_dataset_repeat()
  344. test_penn_treebank_dataset_get_datasetsize()
  345. test_penn_treebank_dataset_to_device()
  346. test_penn_treebank_dataset_exceptions()