You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_udposop.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  18. from util import config_get_set_num_parallel_workers, config_get_set_seed
  19. DATA_DIR = '../data/dataset/testUDPOSDataset/'  # relative path passed as the first argument to ds.UDPOSDataset in the tests below
  20. def test_udpos_dataset_one_file():
  21. """
  22. Feature: Test UDPOS Dataset.
  23. Description: read one file
  24. Expectation: throw number of data in a file
  25. """
  26. data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
  27. count = 0
  28. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  29. logger.info("{}".format(i["word"]))
  30. count += 1
  31. assert count == 1
  32. def test_udpos_dataset_all_file():
  33. """
  34. Feature: Test UDPOS Dataset.
  35. Description: read all file
  36. Expectation: throw number of data in all file
  37. """
  38. data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=False)
  39. count = 0
  40. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  41. logger.info("{}".format(i["word"]))
  42. count += 1
  43. assert count == 6
  44. def test_udpos_dataset_shuffle_false_four_parallel():
  45. """
  46. Feature: Test UDPOS Dataset.
  47. Description: set up four parallel
  48. Expectation: throw data
  49. """
  50. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  51. original_seed = config_get_set_seed(987)
  52. data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=False)
  53. count = 0
  54. numword = 6
  55. line = ["From", "The", "Abed", "Come", "The", "Std",
  56. "What", "Like", "Good", "Mom", "Iike", "Good",
  57. "Abed", "...", "Zoom", "...", "Abed", "From",
  58. "Psg", "Bus", "Ori", "The", "Abed", "The",
  59. "...", "The", "ken", "Ori", "...", "Respect",
  60. "Bus", "Nine", "Job", "Mom", "Abed", "From"]
  61. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  62. for j in range(numword):
  63. strs = i["word"][j].item().decode("utf8")
  64. assert strs == line[count*6+j]
  65. count += 1
  66. assert count == 6
  67. # Restore configuration
  68. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  69. ds.config.set_seed(original_seed)
  70. def test_udpos_dataset_shuffle_false_one_parallel():
  71. """
  72. Feature: Test UDPOS Dataset.
  73. Description: no parallelism set
  74. Expectation: throw data
  75. """
  76. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  77. original_seed = config_get_set_seed(987)
  78. data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=False)
  79. count = 0
  80. numword = 6
  81. line = ["From", "The", "Abed", "Come", "The", "Std",
  82. "Psg", "Bus", "Ori", "The", "Abed", "The",
  83. "Bus", "Nine", "Job", "Mom", "Abed", "From",
  84. "What", "Like", "Good", "Mom", "Iike", "Good",
  85. "Abed", "...", "Zoom", "...", "Abed", "From",
  86. "...", "The", "ken", "Ori", "...", "Respect"]
  87. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  88. for j in range(numword):
  89. strs = i["word"][j].item().decode("utf8")
  90. assert strs == line[count*6+j]
  91. count += 1
  92. assert count == 6
  93. # Restore configuration
  94. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  95. ds.config.set_seed(original_seed)
  96. def test_udpos_dataset_shuffle_files_four_parallel():
  97. """
  98. Feature: Test UDPOS Dataset.
  99. Description: set four parallel and file Disorder
  100. Expectation: throw data
  101. """
  102. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  103. original_seed = config_get_set_seed(135)
  104. data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES)
  105. count = 0
  106. numword = 6
  107. line = ["Abed", "...", "Zoom", "...", "Abed", "From",
  108. "What", "Like", "Good", "Mom", "Iike", "Good",
  109. "From", "The", "Abed", "Come", "The", "Std",
  110. "...", "The", "ken", "Ori", "...", "Respect",
  111. "Psg", "Bus", "Ori", "The", "Abed", "The",
  112. "Bus", "Nine", "Job", "Mom", "Abed", "From"]
  113. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  114. for j in range(numword):
  115. strs = i["word"][j].item().decode("utf8")
  116. assert strs == line[count*6+j]
  117. count += 1
  118. assert count == 6
  119. # Restore configuration
  120. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  121. ds.config.set_seed(original_seed)
  122. def test_udpos_dataset_shuffle_files_one_parallel():
  123. """
  124. Feature: Test UDPOS Dataset.
  125. Description: set no parallelism and file Disorder
  126. Expectation: throw data
  127. """
  128. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  129. original_seed = config_get_set_seed(135)
  130. data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.FILES)
  131. count = 0
  132. numword = 6
  133. line = ["Abed", "...", "Zoom", "...", "Abed", "From",
  134. "...", "The", "ken", "Ori", "...", "Respect",
  135. "What", "Like", "Good", "Mom", "Iike", "Good",
  136. "From", "The", "Abed", "Come", "The", "Std",
  137. "Psg", "Bus", "Ori", "The", "Abed", "The",
  138. "Bus", "Nine", "Job", "Mom", "Abed", "From"]
  139. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  140. for j in range(numword):
  141. strs = i["word"][j].item().decode("utf8")
  142. assert strs == line[count*6+j]
  143. count += 1
  144. assert count == 6
  145. # Restore configuration
  146. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  147. ds.config.set_seed(original_seed)
  148. def test_udpos_dataset_shuffle_global_four_parallel():
  149. """
  150. Feature: Test UDPOS Dataset.
  151. Description: set four parallel and all Disorder
  152. Expectation: throw data
  153. """
  154. original_num_parallel_workers = config_get_set_num_parallel_workers(4)
  155. original_seed = config_get_set_seed(246)
  156. data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
  157. count = 0
  158. numword = 6
  159. line = ["Bus", "Nine", "Job", "Mom", "Abed", "From",
  160. "Abed", "...", "Zoom", "...", "Abed", "From",
  161. "From", "The", "Abed", "Come", "The", "Std",
  162. "Psg", "Bus", "Ori", "The", "Abed", "The",
  163. "What", "Like", "Good", "Mom", "Iike", "Good",
  164. "...", "The", "ken", "Ori", "...", "Respect"]
  165. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  166. for j in range(numword):
  167. strs = i["word"][j].item().decode("utf8")
  168. assert strs == line[count*6+j]
  169. count += 1
  170. assert count == 6
  171. # Restore configuration
  172. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  173. ds.config.set_seed(original_seed)
  174. def test_udpos_dataset_shuffle_global_one_parallel():
  175. """
  176. Feature: Test UDPOS Dataset.
  177. Description: set no parallelism and all Disorder
  178. Expectation: throw data
  179. """
  180. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  181. original_seed = config_get_set_seed(246)
  182. data = ds.UDPOSDataset(DATA_DIR, usage="all", shuffle=ds.Shuffle.GLOBAL)
  183. count = 0
  184. numword = 6
  185. line = ["...", "The", "ken", "Ori", "...", "Respect",
  186. "Psg", "Bus", "Ori", "The", "Abed", "The",
  187. "From", "The", "Abed", "Come", "The", "Std",
  188. "Bus", "Nine", "Job", "Mom", "Abed", "From",
  189. "What", "Like", "Good", "Mom", "Iike", "Good",
  190. "Abed", "...", "Zoom", "...", "Abed", "From"]
  191. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  192. for j in range(numword):
  193. strs = i["word"][j].item().decode("utf8")
  194. assert strs == line[count*6+j]
  195. count += 1
  196. assert count == 6
  197. # Restore configuration
  198. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  199. ds.config.set_seed(original_seed)
  200. def test_udpos_dataset_num_samples():
  201. """
  202. Feature: Test UDPOS Dataset.
  203. Description: read one file
  204. Expectation: throw number of file
  205. """
  206. data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False, num_samples=2)
  207. count = 0
  208. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  209. count += 1
  210. assert count == 1
  211. def test_udpos_dataset_distribution():
  212. """
  213. Feature: Test UDPOS Dataset.
  214. Description: read one file
  215. Expectation: throw number of file
  216. """
  217. data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False, num_shards=2, shard_id=1)
  218. count = 0
  219. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  220. count += 1
  221. assert count == 1
  222. def test_udpos_dataset_repeat():
  223. """
  224. Feature: Test UDPOS Dataset.
  225. Description: repeat read data
  226. Expectation: throw data
  227. """
  228. data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
  229. data = data.repeat(3)
  230. count = 0
  231. numword = 6
  232. line = ["What", "Like", "Good", "Mom", "Iike", "Good",
  233. "What", "Like", "Good", "Mom", "Iike", "Good",
  234. "What", "Like", "Good", "Mom", "Iike", "Good"]
  235. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  236. for j in range(numword):
  237. strs = i["word"][j].item().decode("utf8")
  238. assert strs == line[count*6+j]
  239. count += 1
  240. assert count == 3
  241. def test_udpos_dataset_get_datasetsize():
  242. """
  243. Feature: Test UDPOS Dataset.
  244. Description: repeat read data
  245. Expectation: throw data
  246. """
  247. data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
  248. size = data.get_dataset_size()
  249. assert size == 6
  250. def test_udpos_dataset_to_device():
  251. """
  252. Feature: Test UDPOS Dataset.
  253. Description: transfer data from CPU to other devices
  254. Expectation: send
  255. """
  256. data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
  257. data = data.to_device()
  258. data.send()
  259. def test_udpos_dataset_exceptions():
  260. """
  261. Feature: Test UDPOS Dataset.
  262. Description: send error when error occur
  263. Expectation: send error
  264. """
  265. with pytest.raises(ValueError) as error_info:
  266. _ = ds.UDPOSDataset(DATA_DIR, usage="test", num_samples=-1)
  267. assert "num_samples exceeds the boundary" in str(error_info.value)
  268. with pytest.raises(ValueError) as error_info:
  269. _ = ds.UDPOSDataset("NotExistFile", usage="test")
  270. assert "The folder NotExistFile does not exist or is not a directory or permission denied!" in str(error_info.value)
  271. with pytest.raises(ValueError) as error_info:
  272. _ = ds.TextFileDataset("")
  273. assert "The following patterns did not match any files" in str(error_info.value)
  274. def exception_func(item):
  275. raise Exception("Error occur!")
  276. with pytest.raises(RuntimeError) as error_info:
  277. data = data = ds.UDPOSDataset(DATA_DIR, usage="test", shuffle=False)
  278. data = data.map(operations=exception_func, input_columns=["word"], num_parallel_workers=1)
  279. for _ in data.__iter__():
  280. pass
  281. assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
  282. if __name__ == "__main__":
  283. test_udpos_dataset_one_file()
  284. test_udpos_dataset_all_file()
  285. test_udpos_dataset_shuffle_false_four_parallel()
  286. test_udpos_dataset_shuffle_false_one_parallel()
  287. test_udpos_dataset_shuffle_files_one_parallel()
  288. test_udpos_dataset_shuffle_files_four_parallel()
  289. test_udpos_dataset_shuffle_global_four_parallel()
  290. test_udpos_dataset_shuffle_global_one_parallel()
  291. test_udpos_dataset_num_samples()
  292. test_udpos_dataset_distribution()
  293. test_udpos_dataset_repeat()
  294. test_udpos_dataset_get_datasetsize()
  295. test_udpos_dataset_to_device()
  296. test_udpos_dataset_exceptions()