You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_conll2000.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  18. from util import config_get_set_num_parallel_workers, config_get_set_seed
# Root directory holding the CoNLL2000 chunking test data files.
DATA_DIR = '../data/dataset/testCoNLL2000Dataset'
  20. def test_conll2000_dataset_one_file():
  21. """
  22. Feature: CoNLL2000ChunkingDataset.
  23. Description: test param check of CoNLL2000ChunkingDataset.
  24. Expectation: throw correct error and message.
  25. """
  26. data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
  27. count = 0
  28. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  29. logger.info("{}".format(i["word"]))
  30. count += 1
  31. assert count == 2
  32. def test_conll2000_dataset_all_file():
  33. """
  34. Feature: CoNLL2000ChunkingDataset.
  35. Description: test param check of CoNLL2000ChunkingDataset.
  36. Expectation: throw correct error and message.
  37. """
  38. data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False)
  39. count = 0
  40. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  41. logger.info("{}".format(i["word"]))
  42. count += 1
  43. assert count == 5
  44. def test_conll2000_dataset_num_samples_none():
  45. """
  46. Feature: CoNLL2000ChunkingDataset
  47. Description: test param check of CoNLL2000ChunkingDataset
  48. Expectation: throw correct error and message
  49. """
  50. # Do not provide a num_samples argument, so it would be None by default
  51. data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
  52. count = 0
  53. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  54. logger.info("{}".format(i["word"]))
  55. count += 1
  56. assert count == 2
  57. def test_conll2000_dataset_shuffle_false_num_parallel_workers_4():
  58. """
  59. Feature: CoNLL2000ChunkingDataset.
  60. Description: test param check of CoNLL2000ChunkingDataset.
  61. Expectation: throw correct error and message.
  62. """
  63. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  64. original_seed = config_get_set_seed(987)
  65. data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False)
  66. count = 0
  67. numword = 5
  68. line = ["He", "reckons", "the", "current", "account", ".",
  69. "Challenge", "of", "the", "August", "month", ".",
  70. "The", "1.8", "billion", "in", "September", ".",
  71. "Her", "'s", "chancellor", "at", "Lawson", ".",
  72. "To", "economists", ",", "foreign", "exchange", "."]
  73. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  74. for j in range(numword):
  75. strs = i["word"][j].item().decode("utf8")
  76. assert strs == line[count*6+j]
  77. count += 1
  78. assert count == 5
  79. # Restore configuration
  80. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  81. ds.config.set_seed(original_seed)
  82. def test_conll2000_dataset_shuffle_false_num_parallel_workers_1():
  83. """
  84. Feature: CoNLL2000ChunkingDataset.
  85. Description: test param check of CoNLL2000ChunkingDataset.
  86. Expectation: throw correct error and message.
  87. """
  88. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  89. original_seed = config_get_set_seed(987)
  90. data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=False)
  91. count = 0
  92. numword = 6
  93. line = ["He", "reckons", "the", "current", "account", ".",
  94. "The", "1.8", "billion", "in", "September", ".",
  95. "Challenge", "of", "the", "August", "month", ".",
  96. "Her", "'s", "chancellor", "at", "Lawson", ".",
  97. "To", "economists", ",", "foreign", "exchange", "."]
  98. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  99. for j in range(numword):
  100. strs = i["word"][j].item().decode("utf8")
  101. assert strs == line[count*6+j]
  102. count += 1
  103. assert count == 5
  104. # Restore configuration
  105. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  106. ds.config.set_seed(original_seed)
  107. def test_conll2000_dataset_shuffle_files_num_parallel_workers_4():
  108. """
  109. Feature: CoNLL2000ChunkingDataset.
  110. Description: test param check of CoNLL2000ChunkingDataset.
  111. Expectation: throw correct error and message.
  112. """
  113. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  114. original_seed = config_get_set_seed(135)
  115. data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES)
  116. count = 0
  117. numword = 6
  118. line = ["He", "reckons", "the", "current", "account", ".",
  119. "Challenge", "of", "the", "August", "month", ".",
  120. "The", "1.8", "billion", "in", "September", ".",
  121. "Her", "'s", "chancellor", "at", "Lawson", ".",
  122. "To", "economists", ",", "foreign", "exchange", "."]
  123. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  124. for j in range(numword):
  125. strs = i["word"][j].item().decode("utf8")
  126. assert strs == line[count*6+j]
  127. count += 1
  128. assert count == 5
  129. # Restore configuration
  130. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  131. ds.config.set_seed(original_seed)
  132. def test_conll2000_dataset_shuffle_files_num_parallel_workers_1():
  133. """
  134. Feature: CoNLL2000ChunkingDataset.
  135. Description: test param check of CoNLL2000ChunkingDataset.
  136. Expectation: throw correct error and message.
  137. """
  138. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  139. original_seed = config_get_set_seed(135)
  140. data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES)
  141. count = 0
  142. numword = 6
  143. line = ["He", "reckons", "the", "current", "account", ".",
  144. "The", "1.8", "billion", "in", "September", ".",
  145. "Challenge", "of", "the", "August", "month", ".",
  146. "Her", "'s", "chancellor", "at", "Lawson", ".",
  147. "To", "economists", ",", "foreign", "exchange", "."]
  148. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  149. for j in range(numword):
  150. strs = i["word"][j].item().decode("utf8")
  151. assert strs == line[count*6+j]
  152. count += 1
  153. assert count == 5
  154. # Restore configuration
  155. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  156. ds.config.set_seed(original_seed)
  157. def test_conll2000_dataset_shuffle_global_num_parallel_workers_4():
  158. """
  159. Feature: CoNLL2000ChunkingDataset.
  160. Description: test param check of CoNLL2000ChunkingDataset.
  161. Expectation: throw correct error and message.
  162. """
  163. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  164. original_seed = config_get_set_seed(246)
  165. data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
  166. count = 0
  167. numword = 6
  168. line = ["Challenge", "of", "the", "August", "month", ".",
  169. "To", "economists", ",", "foreign", "exchange", ".",
  170. "Her", "'s", "chancellor", "at", "Lawson", ".",
  171. "He", "reckons", "the", "current", "account", ".",
  172. "The", "1.8", "billion", "in", "September", "."]
  173. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  174. for j in range(numword):
  175. strs = i["word"][j].item().decode("utf8")
  176. assert strs == line[count*6+j]
  177. count += 1
  178. assert count == 5
  179. # Restore configuration
  180. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  181. ds.config.set_seed(original_seed)
  182. def test_conll2000_dataset_shuffle_global_num_parallel_workers_1():
  183. """
  184. Feature: CoNLL2000ChunkingDataset.
  185. Description: test param check of CoNLL2000ChunkingDataset.
  186. Expectation: throw correct error and message.
  187. """
  188. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  189. original_seed = config_get_set_seed(246)
  190. data = ds.CoNLL2000Dataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
  191. count = 0
  192. numword = 6
  193. line = ["Challenge", "of", "the", "August", "month", ".",
  194. "The", "1.8", "billion", "in", "September", ".",
  195. "To", "economists", ",", "foreign", "exchange", ".",
  196. "Her", "'s", "chancellor", "at", "Lawson", ".",
  197. "He", "reckons", "the", "current", "account", "."]
  198. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  199. for j in range(numword):
  200. strs = i["word"][j].item().decode("utf8")
  201. assert strs == line[count*6+j]
  202. count += 1
  203. assert count == 5
  204. # Restore configuration
  205. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  206. ds.config.set_seed(original_seed)
  207. def test_conll2000_dataset_num_samples():
  208. """
  209. Feature: CoNLL2000ChunkingDataset.
  210. Description: test param check of CoNLL2000ChunkingDataset.
  211. Expectation: throw correct error and message.
  212. """
  213. data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False, num_samples=2)
  214. count = 0
  215. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  216. count += 1
  217. assert count == 2
  218. def test_conll2000_dataset_distribution():
  219. """
  220. Feature: CoNLL2000ChunkingDataset.
  221. Description: test param check of CoNLL2000ChunkingDataset.
  222. Expectation: throw correct error and message.
  223. """
  224. data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False, num_shards=2, shard_id=1)
  225. count = 0
  226. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  227. count += 1
  228. assert count == 1
  229. def test_conll2000_dataset_repeat():
  230. """
  231. Feature: CoNLL2000ChunkingDataset.
  232. Description: test param check of CoNLL2000ChunkingDataset.
  233. Expectation: throw correct error and message.
  234. """
  235. data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
  236. data = data.repeat(3)
  237. count = 0
  238. numword = 6
  239. line = ["He", "reckons", "the", "current", "account", ".",
  240. "The", "1.8", "billion", "in", "September", ".",
  241. "He", "reckons", "the", "current", "account", ".",
  242. "The", "1.8", "billion", "in", "September", ".",
  243. "He", "reckons", "the", "current", "account", ".",
  244. "The", "1.8", "billion", "in", "September", ".",]
  245. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  246. for j in range(numword):
  247. strs = i["word"][j].item().decode("utf8")
  248. assert strs == line[count*6+j]
  249. count += 1
  250. assert count == 6
  251. def test_conll2000_dataset_get_datasetsize():
  252. """
  253. Feature: CoNLL2000ChunkingDataset.
  254. Description: test param check of CoNLL2000ChunkingDataset.
  255. Expectation: throw correct error and message.
  256. """
  257. data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
  258. size = data.get_dataset_size()
  259. assert size == 12
  260. def test_conll2000_dataset_to_device():
  261. """
  262. Feature: CoNLL2000ChunkingDataset.
  263. Description: test param check of CoNLL2000ChunkingDataset.
  264. Expectation: throw correct error and message.
  265. """
  266. data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
  267. data = data.to_device()
  268. data.send()
  269. def test_conll2000_dataset_exceptions():
  270. """
  271. Feature: CoNLL2000ChunkingDataset.
  272. Description: test param check of CoNLL2000ChunkingDataset.
  273. Expectation: throw correct error and message.
  274. """
  275. with pytest.raises(ValueError) as error_info:
  276. _ = ds.CoNLL2000Dataset(DATA_DIR, usage="test", num_samples=-1)
  277. assert "num_samples exceeds the boundary" in str(error_info.value)
  278. with pytest.raises(ValueError) as error_info:
  279. _ = ds.CoNLL2000Dataset("NotExistFile", usage="test")
  280. assert "The folder NotExistFile does not exist or is not a directory or permission denied!" in str(error_info.value)
  281. with pytest.raises(ValueError) as error_info:
  282. _ = ds.TextFileDataset("")
  283. assert "The following patterns did not match any files" in str(error_info.value)
  284. def exception_func(item):
  285. raise Exception("Error occur!")
  286. with pytest.raises(RuntimeError) as error_info:
  287. data = data = ds.CoNLL2000Dataset(DATA_DIR, usage="test", shuffle=False)
  288. data = data.map(operations=exception_func, input_columns=["word"], num_parallel_workers=1)
  289. for _ in data.__iter__():
  290. pass
  291. assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
  292. if __name__ == "__main__":
  293. test_conll2000_dataset_one_file()
  294. test_conll2000_dataset_all_file()
  295. test_conll2000_dataset_num_samples_none()
  296. test_conll2000_dataset_shuffle_false_num_parallel_workers_4()
  297. test_conll2000_dataset_shuffle_false_num_parallel_workers_1()
  298. test_conll2000_dataset_shuffle_files_num_parallel_workers_4()
  299. test_conll2000_dataset_shuffle_files_num_parallel_workers_1()
  300. test_conll2000_dataset_shuffle_global_num_parallel_workers_4()
  301. test_conll2000_dataset_shuffle_global_num_parallel_workers_1()
  302. test_conll2000_dataset_num_samples()
  303. test_conll2000_dataset_distribution()
  304. test_conll2000_dataset_repeat()
  305. test_conll2000_dataset_get_datasetsize()
  306. test_conll2000_dataset_to_device()
  307. test_conll2000_dataset_exceptions()