You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_datasets_squad.py 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import pytest
  16. import mindspore.dataset as ds
  17. DATASET_DIR_V1 = '../data/dataset/testSQuAD/SQuAD1'
  18. DATASET_DIR_V2 = '../data/dataset/testSQuAD/SQuAD2'
  19. def test_squad_basic():
  20. """
  21. Feature: SQuADDataset.
  22. Description: test SQuADDataset with repeat, skip and so on.
  23. Expectation: the data is processed successfully.
  24. """
  25. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', shuffle=False)
  26. data = data.repeat(2)
  27. data = data.skip(3)
  28. expected_result = ["Who is \"The Father of Modern Computers\"?",
  29. "When was John von Neumann's birth date?",
  30. "Where is John von Neumann's birthplace?"]
  31. count = 0
  32. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  33. assert d['question'].item().decode("utf8") == expected_result[count]
  34. count += 1
  35. assert count == 3
  36. def test_squad_num_shards():
  37. """
  38. Feature: SQuADDataset.
  39. Description: test num_shards param of SQuAD dataset.
  40. Expectation: the data is processed successfully.
  41. """
  42. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train',
  43. num_shards=3, shard_id=2)
  44. expected_result = ["Where is John von Neumann's birthplace?"]
  45. count = 0
  46. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  47. assert d['question'].item().decode("utf8") == expected_result[count]
  48. count += 1
  49. assert count == 1
  50. def test_squad_num_samples():
  51. """
  52. Feature: SQuADDataset.
  53. Description: test num_samples param of SQuAD dataset.
  54. Expectation: the data is processed successfully.
  55. """
  56. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', num_samples=2)
  57. count = 0
  58. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  59. count += 1
  60. assert count == 2
  61. def test_squad_dataset_get_datasetsize():
  62. """
  63. Feature: SQuADDataset.
  64. Description: test get_dataset_size of SQuAD dataset.
  65. Expectation: the data is processed successfully.
  66. """
  67. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
  68. size = data.get_dataset_size()
  69. assert size == 3
  70. def test_squad_version1():
  71. """
  72. Feature: SQuADDataset.
  73. Description: test SQuAD 1.1 for train, dev and all.
  74. Expectation: the data is processed successfully.
  75. """
  76. # train
  77. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', shuffle=False)
  78. expected_result = ["Who is \"The Father of Modern Computers\"?",
  79. "When was John von Neumann's birth date?",
  80. "Where is John von Neumann's birthplace?"]
  81. count = 0
  82. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  83. assert d['question'].item().decode("utf8") == expected_result[count]
  84. count += 1
  85. assert count == 3
  86. # dev
  87. data = ds.SQuADDataset(DATASET_DIR_V1, usage='dev', shuffle=False)
  88. expected_result = ["\"The Mathematical Principles of Natural Philosophy\" is a philosophical philosophy " +
  89. "of physics created by British Cognitive Isaac Newton. It was first published in 1687.",
  90. "\"The Mathematical Principles of Natural Philosophy\" is a philosophical philosophy " +
  91. "of physics created by British Cognitive Isaac Newton. It was first published in 1687."]
  92. count = 0
  93. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  94. assert d['context'].item().decode("utf8") == expected_result[count]
  95. count += 1
  96. assert count == 2
  97. # all
  98. data = ds.SQuADDataset(DATASET_DIR_V1, usage='all', shuffle=False)
  99. expected_result = [[0], [122, 122, 122], [18], [162, 162, 162], [55]]
  100. count = 0
  101. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  102. assert [i.item() for i in d['answer_start']] == expected_result[count]
  103. count += 1
  104. assert count == 5
  105. def test_squad_version2():
  106. """
  107. Feature: SQuADDataset.
  108. Description: test SQuAD2.0 for train, dev and all.
  109. Expectation: the data is processed successfully.
  110. """
  111. # train
  112. data = ds.SQuADDataset(DATASET_DIR_V2, usage='train', shuffle=False)
  113. expected_result = ["Stephen William Hawking, born on January 8, 1942 in Oxford, England, " +
  114. "is one of the greatest modern physicists.",
  115. "Stephen William Hawking, born on January 8, 1942 in Oxford, England, " +
  116. "is one of the greatest modern physicists."]
  117. count = 0
  118. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  119. assert d['context'].item().decode("utf8") == expected_result[count]
  120. count += 1
  121. assert count == 2
  122. # dev
  123. data = ds.SQuADDataset(DATASET_DIR_V2, usage='dev', shuffle=False)
  124. expected_result = ["What is the lifestyle of dolphins?",
  125. "Who ate the squid?"]
  126. count = 0
  127. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  128. assert d['question'].item().decode("utf8") == expected_result[count]
  129. count += 1
  130. assert count == 2
  131. # all
  132. data = ds.SQuADDataset(DATASET_DIR_V2, usage='all', shuffle=False)
  133. expected_result = [["Oxford, England"],
  134. ["live in groups", "live in groups",
  135. "live in groups", "live in groups"],
  136. ["January 8, 1942"], [""]]
  137. count = 0
  138. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  139. result = [i.item().decode("utf8") for i in d['text']]
  140. assert result == expected_result[count]
  141. count += 1
  142. assert count == 4
  143. def test_squad_to_device():
  144. """
  145. Feature: SQuADDataset.
  146. Description: test SQuAD with to_device.
  147. Expectation: the data is processed successfully.
  148. """
  149. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', shuffle=False)
  150. data = data.to_device()
  151. data.send()
  152. def test_squad_invalid_dir():
  153. """
  154. Feature: SQuADDataset.
  155. Description: test SQuAD with invalid dir.
  156. Expectation: throw correct error and message.
  157. """
  158. invalid_dataset_dir = '../data/dataset/invalid_dir'
  159. with pytest.raises(ValueError) as info:
  160. _ = ds.SQuADDataset(invalid_dataset_dir, usage='train', shuffle=False)
  161. assert "The folder " + invalid_dataset_dir + " does not exist or is not a directory or permission denied!" \
  162. in str(info.value)
  163. assert invalid_dataset_dir in str(info.value)
  164. def test_squad_exception():
  165. """
  166. Feature: SQuADDataset.
  167. Description: test file info in err msg when exception occur of SQuAD dataset.
  168. Expectation: unable to read in data.
  169. """
  170. def exception_func(item):
  171. raise Exception("Error occur!")
  172. try:
  173. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
  174. data = data.map(operations=exception_func, input_columns=["context"],
  175. num_parallel_workers=1)
  176. for _ in data.create_dict_iterator():
  177. pass
  178. assert False
  179. except RuntimeError as e:
  180. assert "map operation: [PyFunc] failed. The corresponding data files" \
  181. in str(e)
  182. try:
  183. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
  184. data = data.map(operations=exception_func, input_columns=["question"],
  185. num_parallel_workers=1)
  186. for _ in data.create_dict_iterator():
  187. pass
  188. assert False
  189. except RuntimeError as e:
  190. assert "map operation: [PyFunc] failed. The corresponding data files" \
  191. in str(e)
  192. try:
  193. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
  194. data = data.map(operations=exception_func, input_columns=["answer_start"],
  195. num_parallel_workers=1)
  196. for _ in data.create_dict_iterator():
  197. pass
  198. assert False
  199. except RuntimeError as e:
  200. assert "map operation: [PyFunc] failed. The corresponding data files" \
  201. in str(e)
  202. try:
  203. data = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
  204. data = data.map(operations=exception_func, input_columns=["text"],
  205. num_parallel_workers=1)
  206. for _ in data.create_dict_iterator():
  207. pass
  208. assert False
  209. except RuntimeError as e:
  210. assert "map operation: [PyFunc] failed. The corresponding data files" \
  211. in str(e)
  212. if __name__ == "__main__":
  213. test_squad_basic()
  214. test_squad_num_shards()
  215. test_squad_num_samples()
  216. test_squad_dataset_get_datasetsize()
  217. test_squad_version1()
  218. test_squad_version2()
  219. test_squad_to_device()
  220. test_squad_invalid_dir()
  221. test_squad_exception()