You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_enwik9.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  18. from util import config_get_set_num_parallel_workers, config_get_set_seed
  19. DATA_FILE = "../data/dataset/testEnWik9Dataset"
  20. def test_enwik9_total_rows_dataset_num_samples_none():
  21. """
  22. Feature: EnWik9Dataset
  23. Description: test the function while param num_samples = 0
  24. Expectation: the number of samples is 13
  25. """
  26. # Do not provide a num_samples argument, so it would be None by default.
  27. data = ds.EnWik9Dataset(DATA_FILE)
  28. count = 0
  29. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  30. logger.info("{}".format(i["text"]))
  31. count += 1
  32. assert count == 13
  33. def test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_two():
  34. """
  35. Feature: EnWik9Dataset
  36. Description: test the function while param shuffle = False
  37. Expectation: the samples is ordered
  38. """
  39. original_num_parallel_workers = config_get_set_num_parallel_workers(2)
  40. original_seed = config_get_set_seed(987)
  41. data = ds.EnWik9Dataset(DATA_FILE, shuffle=False)
  42. count = 0
  43. line = [" <page>",
  44. " <title>MindSpore</title>",
  45. " <id>1</id>",
  46. " <revision>",
  47. " <id>234</id>",
  48. " <timestamp>2020-01-01T00:00:00Z</timestamp>",
  49. " <contributor>",
  50. " <username>MS</username>",
  51. " <id>567</id>",
  52. " </contributor>",
  53. " <text xml:space=\"preserve\">666</text>",
  54. " </revision>",
  55. " </page>"]
  56. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  57. strs = i["text"].item().decode("utf8")
  58. assert strs == line[count]
  59. count += 1
  60. assert count == 13
  61. # Restore configuration.
  62. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  63. ds.config.set_seed(original_seed)
  64. def test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_one():
  65. """
  66. Feature: EnWik9Dataset
  67. Description: test the function while param shuffle = False
  68. Expectation: the samples is ordered
  69. """
  70. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  71. original_seed = config_get_set_seed(987)
  72. data = ds.EnWik9Dataset(DATA_FILE, shuffle=False)
  73. count = 0
  74. line = [" <page>",
  75. " <title>MindSpore</title>",
  76. " <id>1</id>",
  77. " <revision>",
  78. " <id>234</id>",
  79. " <timestamp>2020-01-01T00:00:00Z</timestamp>",
  80. " <contributor>",
  81. " <username>MS</username>",
  82. " <id>567</id>",
  83. " </contributor>",
  84. " <text xml:space=\"preserve\">666</text>",
  85. " </revision>",
  86. " </page>"]
  87. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  88. strs = i["text"].item().decode("utf8")
  89. assert strs == line[count]
  90. count += 1
  91. assert count == 13
  92. # Restore configuration.
  93. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  94. ds.config.set_seed(original_seed)
  95. def test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_two():
  96. """
  97. Feature: EnWik9Dataset
  98. Description: test the function while param shuffle = True
  99. Expectation: the samples is disorder
  100. """
  101. original_num_parallel_workers = config_get_set_num_parallel_workers(2)
  102. original_seed = config_get_set_seed(135)
  103. data = ds.EnWik9Dataset(DATA_FILE, shuffle=True)
  104. count = 0
  105. line = [" <username>MS</username>",
  106. " <title>MindSpore</title>",
  107. " <id>234</id>",
  108. " </revision>",
  109. " </contributor>",
  110. " <revision>",
  111. " <id>567</id>",
  112. " <timestamp>2020-01-01T00:00:00Z</timestamp>",
  113. " <id>1</id>",
  114. " </page>",
  115. " <page>",
  116. " <text xml:space=\"preserve\">666</text>",
  117. " <contributor>"]
  118. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  119. strs = i["text"].item().decode("utf8")
  120. assert strs == line[count]
  121. count += 1
  122. assert count == 13
  123. # Restore configuration.
  124. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  125. ds.config.set_seed(original_seed)
  126. def test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_one():
  127. """
  128. Feature: EnWik9Dataset
  129. Description: test the function while param shuffle = True
  130. Expectation: the samples is disorder
  131. """
  132. original_num_parallel_workers = config_get_set_num_parallel_workers(1)
  133. original_seed = config_get_set_seed(135)
  134. data = ds.EnWik9Dataset(DATA_FILE, shuffle=True)
  135. count = 0
  136. line = [" <username>MS</username>",
  137. " <title>MindSpore</title>",
  138. " <id>234</id>",
  139. " </revision>",
  140. " </contributor>",
  141. " <revision>",
  142. " <id>567</id>",
  143. " <timestamp>2020-01-01T00:00:00Z</timestamp>",
  144. " <id>1</id>",
  145. " </page>",
  146. " <page>",
  147. " <text xml:space=\"preserve\">666</text>",
  148. " <contributor>"]
  149. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  150. strs = i["text"].item().decode("utf8")
  151. assert strs == line[count]
  152. count += 1
  153. assert count == 13
  154. # Restore configuration.
  155. ds.config.set_num_parallel_workers(original_num_parallel_workers)
  156. ds.config.set_seed(original_seed)
  157. def test_enwik9_dataset_num_samples():
  158. """
  159. Feature: EnWik9Dataset
  160. Description: test param num_samples, while it = 2
  161. Expectation: the number of samples = 2
  162. """
  163. data = ds.EnWik9Dataset(DATA_FILE, num_samples=2)
  164. count = 0
  165. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  166. count += 1
  167. assert count == 2
  168. def test_enwik9_dataset_distribution():
  169. """
  170. Feature: EnWik9Dataset
  171. Description: test distribution of the dataset
  172. Expectation: count = 7
  173. """
  174. data = ds.EnWik9Dataset(DATA_FILE, num_shards=2, shard_id=1)
  175. count = 0
  176. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  177. count += 1
  178. assert count == 7
  179. def test_enwik9_total_rows_dataset_repeat():
  180. """
  181. Feature: EnWik9Dataset
  182. Description: test the function whie the samples are repeat
  183. Expectation: count = 26
  184. """
  185. data = ds.EnWik9Dataset(DATA_FILE, shuffle=False)
  186. data = data.repeat(2)
  187. count = 0
  188. line = [" <page>",
  189. " <title>MindSpore</title>",
  190. " <id>1</id>",
  191. " <revision>",
  192. " <id>234</id>",
  193. " <timestamp>2020-01-01T00:00:00Z</timestamp>",
  194. " <contributor>",
  195. " <username>MS</username>",
  196. " <id>567</id>",
  197. " </contributor>",
  198. " <text xml:space=\"preserve\">666</text>",
  199. " </revision>",
  200. " </page>",
  201. " <page>",
  202. " <title>MindSpore</title>",
  203. " <id>1</id>",
  204. " <revision>",
  205. " <id>234</id>",
  206. " <timestamp>2020-01-01T00:00:00Z</timestamp>",
  207. " <contributor>",
  208. " <username>MS</username>",
  209. " <id>567</id>",
  210. " </contributor>",
  211. " <text xml:space=\"preserve\">666</text>",
  212. " </revision>",
  213. " </page>"]
  214. for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  215. strs = i["text"].item().decode("utf8")
  216. assert strs == line[count]
  217. count += 1
  218. assert count == 26
  219. def test_enwik9_total_rows_dataset_get_datasetsize():
  220. """
  221. Feature: EnWik9Dataset
  222. Description: test the function, get_dataset_size()
  223. Expectation: size = 13
  224. """
  225. data = ds.EnWik9Dataset(DATA_FILE)
  226. size = data.get_dataset_size()
  227. assert size == 13
  228. def test_enwik9_total_rows_dataset_to_device():
  229. """
  230. Feature: EnWik9Dataset
  231. Description: test the function, to_device()
  232. Expectation: size = 13
  233. """
  234. data = ds.EnWik9Dataset(DATA_FILE, shuffle=False)
  235. data = data.to_device()
  236. data.send()
  237. def test_enwik9_dataset_exceptions():
  238. """
  239. Feature: EnWik9Dataset
  240. Description: test the errors which appear possibly
  241. Expectation: the errors are expected correctly
  242. """
  243. with pytest.raises(ValueError) as error_info:
  244. _ = ds.EnWik9Dataset("does/not/exist/")
  245. assert "does not exist or is not a directory or permission denied" in str(error_info.value)
  246. with pytest.raises(ValueError) as error_info:
  247. _ = ds.EnWik9Dataset("")
  248. assert "The folder does not exist or is not a directory or permission denied" in str(error_info.value)
  249. def exception_func(item):
  250. raise Exception("Error occur!")
  251. with pytest.raises(RuntimeError) as error_info:
  252. data = ds.EnWik9Dataset(DATA_FILE)
  253. data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1)
  254. for _ in data.__iter__():
  255. pass
  256. assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value)
  257. if __name__ == "__main__":
  258. test_enwik9_total_rows_dataset_num_samples_none()
  259. test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_two()
  260. test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_one()
  261. test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_two()
  262. test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_one()
  263. test_enwik9_dataset_num_samples()
  264. test_enwik9_dataset_distribution()
  265. test_enwik9_total_rows_dataset_repeat()
  266. test_enwik9_total_rows_dataset_get_datasetsize()
  267. test_enwik9_total_rows_dataset_to_device()
  268. test_enwik9_dataset_exceptions()