You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_clue.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import os
  16. import pytest
  17. import mindspore.dataset as ds
  18. def test_clue():
  19. """
  20. Test CLUE with repeat, skip and so on
  21. """
  22. TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
  23. buffer = []
  24. data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
  25. data = data.repeat(2)
  26. data = data.skip(3)
  27. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  28. buffer.append({
  29. 'label': d['label'].item().decode("utf8"),
  30. 'sentence1': d['sentence1'].item().decode("utf8"),
  31. 'sentence2': d['sentence2'].item().decode("utf8")
  32. })
  33. assert len(buffer) == 3
  34. def test_clue_num_shards():
  35. """
  36. Test num_shards param of CLUE dataset
  37. """
  38. TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
  39. buffer = []
  40. data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_shards=3, shard_id=1)
  41. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  42. buffer.append({
  43. 'label': d['label'].item().decode("utf8"),
  44. 'sentence1': d['sentence1'].item().decode("utf8"),
  45. 'sentence2': d['sentence2'].item().decode("utf8")
  46. })
  47. assert len(buffer) == 1
  48. def test_clue_num_samples():
  49. """
  50. Test num_samples param of CLUE dataset
  51. """
  52. TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
  53. data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', num_samples=2)
  54. count = 0
  55. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  56. count += 1
  57. assert count == 2
  58. def test_textline_dataset_get_datasetsize():
  59. """
  60. Test get_dataset_size of CLUE dataset
  61. """
  62. TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
  63. data = ds.TextFileDataset(TRAIN_FILE)
  64. size = data.get_dataset_size()
  65. assert size == 3
  66. def test_clue_afqmc():
  67. """
  68. Test AFQMC for train, test and evaluation
  69. """
  70. TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
  71. TEST_FILE = '../data/dataset/testCLUE/afqmc/test.json'
  72. EVAL_FILE = '../data/dataset/testCLUE/afqmc/dev.json'
  73. # train
  74. buffer = []
  75. data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
  76. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  77. buffer.append({
  78. 'label': d['label'].item().decode("utf8"),
  79. 'sentence1': d['sentence1'].item().decode("utf8"),
  80. 'sentence2': d['sentence2'].item().decode("utf8")
  81. })
  82. assert len(buffer) == 3
  83. # test
  84. buffer = []
  85. data = ds.CLUEDataset(TEST_FILE, task='AFQMC', usage='test', shuffle=False)
  86. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  87. buffer.append({
  88. 'id': d['id'],
  89. 'sentence1': d['sentence1'].item().decode("utf8"),
  90. 'sentence2': d['sentence2'].item().decode("utf8")
  91. })
  92. assert len(buffer) == 3
  93. # evaluation
  94. buffer = []
  95. data = ds.CLUEDataset(EVAL_FILE, task='AFQMC', usage='eval', shuffle=False)
  96. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  97. buffer.append({
  98. 'label': d['label'].item().decode("utf8"),
  99. 'sentence1': d['sentence1'].item().decode("utf8"),
  100. 'sentence2': d['sentence2'].item().decode("utf8")
  101. })
  102. assert len(buffer) == 3
  103. def test_clue_cmnli():
  104. """
  105. Test CMNLI for train, test and evaluation
  106. """
  107. TRAIN_FILE = '../data/dataset/testCLUE/cmnli/train.json'
  108. TEST_FILE = '../data/dataset/testCLUE/cmnli/test.json'
  109. EVAL_FILE = '../data/dataset/testCLUE/cmnli/dev.json'
  110. # train
  111. buffer = []
  112. data = ds.CLUEDataset(TRAIN_FILE, task='CMNLI', usage='train', shuffle=False)
  113. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  114. buffer.append({
  115. 'label': d['label'].item().decode("utf8"),
  116. 'sentence1': d['sentence1'].item().decode("utf8"),
  117. 'sentence2': d['sentence2'].item().decode("utf8")
  118. })
  119. assert len(buffer) == 3
  120. # test
  121. buffer = []
  122. data = ds.CLUEDataset(TEST_FILE, task='CMNLI', usage='test', shuffle=False)
  123. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  124. buffer.append({
  125. 'id': d['id'],
  126. 'sentence1': d['sentence1'],
  127. 'sentence2': d['sentence2']
  128. })
  129. assert len(buffer) == 3
  130. # eval
  131. buffer = []
  132. data = ds.CLUEDataset(EVAL_FILE, task='CMNLI', usage='eval', shuffle=False)
  133. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  134. buffer.append({
  135. 'label': d['label'],
  136. 'sentence1': d['sentence1'],
  137. 'sentence2': d['sentence2']
  138. })
  139. assert len(buffer) == 3
  140. def test_clue_csl():
  141. """
  142. Test CSL for train, test and evaluation
  143. """
  144. TRAIN_FILE = '../data/dataset/testCLUE/csl/train.json'
  145. TEST_FILE = '../data/dataset/testCLUE/csl/test.json'
  146. EVAL_FILE = '../data/dataset/testCLUE/csl/dev.json'
  147. # train
  148. buffer = []
  149. data = ds.CLUEDataset(TRAIN_FILE, task='CSL', usage='train', shuffle=False)
  150. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  151. buffer.append({
  152. 'id': d['id'],
  153. 'abst': d['abst'].item().decode("utf8"),
  154. 'keyword': [i.item().decode("utf8") for i in d['keyword']],
  155. 'label': d['label'].item().decode("utf8")
  156. })
  157. assert len(buffer) == 3
  158. # test
  159. buffer = []
  160. data = ds.CLUEDataset(TEST_FILE, task='CSL', usage='test', shuffle=False)
  161. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  162. buffer.append({
  163. 'id': d['id'],
  164. 'abst': d['abst'].item().decode("utf8"),
  165. 'keyword': [i.item().decode("utf8") for i in d['keyword']],
  166. })
  167. assert len(buffer) == 3
  168. # eval
  169. buffer = []
  170. data = ds.CLUEDataset(EVAL_FILE, task='CSL', usage='eval', shuffle=False)
  171. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  172. buffer.append({
  173. 'id': d['id'],
  174. 'abst': d['abst'].item().decode("utf8"),
  175. 'keyword': [i.item().decode("utf8") for i in d['keyword']],
  176. 'label': d['label'].item().decode("utf8")
  177. })
  178. assert len(buffer) == 3
  179. def test_clue_iflytek():
  180. """
  181. Test IFLYTEK for train, test and evaluation
  182. """
  183. TRAIN_FILE = '../data/dataset/testCLUE/iflytek/train.json'
  184. TEST_FILE = '../data/dataset/testCLUE/iflytek/test.json'
  185. EVAL_FILE = '../data/dataset/testCLUE/iflytek/dev.json'
  186. # train
  187. buffer = []
  188. data = ds.CLUEDataset(TRAIN_FILE, task='IFLYTEK', usage='train', shuffle=False)
  189. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  190. buffer.append({
  191. 'label': d['label'].item().decode("utf8"),
  192. 'label_des': d['label_des'].item().decode("utf8"),
  193. 'sentence': d['sentence'].item().decode("utf8"),
  194. })
  195. assert len(buffer) == 3
  196. # test
  197. buffer = []
  198. data = ds.CLUEDataset(TEST_FILE, task='IFLYTEK', usage='test', shuffle=False)
  199. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  200. buffer.append({
  201. 'id': d['id'],
  202. 'sentence': d['sentence'].item().decode("utf8")
  203. })
  204. assert len(buffer) == 3
  205. # eval
  206. buffer = []
  207. data = ds.CLUEDataset(EVAL_FILE, task='IFLYTEK', usage='eval', shuffle=False)
  208. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  209. buffer.append({
  210. 'label': d['label'].item().decode("utf8"),
  211. 'label_des': d['label_des'].item().decode("utf8"),
  212. 'sentence': d['sentence'].item().decode("utf8")
  213. })
  214. assert len(buffer) == 3
  215. def test_clue_tnews():
  216. """
  217. Test TNEWS for train, test and evaluation
  218. """
  219. TRAIN_FILE = '../data/dataset/testCLUE/tnews/train.json'
  220. TEST_FILE = '../data/dataset/testCLUE/tnews/test.json'
  221. EVAL_FILE = '../data/dataset/testCLUE/tnews/dev.json'
  222. # train
  223. buffer = []
  224. data = ds.CLUEDataset(TRAIN_FILE, task='TNEWS', usage='train', shuffle=False)
  225. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  226. buffer.append({
  227. 'label': d['label'].item().decode("utf8"),
  228. 'label_desc': d['label_desc'].item().decode("utf8"),
  229. 'sentence': d['sentence'].item().decode("utf8"),
  230. 'keywords':
  231. d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
  232. })
  233. assert len(buffer) == 3
  234. # test
  235. buffer = []
  236. data = ds.CLUEDataset(TEST_FILE, task='TNEWS', usage='test', shuffle=False)
  237. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  238. buffer.append({
  239. 'id': d['id'],
  240. 'sentence': d['sentence'].item().decode("utf8"),
  241. 'keywords':
  242. d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
  243. })
  244. assert len(buffer) == 3
  245. # eval
  246. buffer = []
  247. data = ds.CLUEDataset(EVAL_FILE, task='TNEWS', usage='eval', shuffle=False)
  248. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  249. buffer.append({
  250. 'label': d['label'].item().decode("utf8"),
  251. 'label_desc': d['label_desc'].item().decode("utf8"),
  252. 'sentence': d['sentence'].item().decode("utf8"),
  253. 'keywords':
  254. d['keywords'].item().decode("utf8") if d['keywords'].size > 0 else d['keywords']
  255. })
  256. assert len(buffer) == 3
  257. def test_clue_wsc():
  258. """
  259. Test WSC for train, test and evaluation
  260. """
  261. TRAIN_FILE = '../data/dataset/testCLUE/wsc/train.json'
  262. TEST_FILE = '../data/dataset/testCLUE/wsc/test.json'
  263. EVAL_FILE = '../data/dataset/testCLUE/wsc/dev.json'
  264. # train
  265. buffer = []
  266. data = ds.CLUEDataset(TRAIN_FILE, task='WSC', usage='train')
  267. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  268. buffer.append({
  269. 'span1_index': d['span1_index'],
  270. 'span2_index': d['span2_index'],
  271. 'span1_text': d['span1_text'].item().decode("utf8"),
  272. 'span2_text': d['span2_text'].item().decode("utf8"),
  273. 'idx': d['idx'],
  274. 'label': d['label'].item().decode("utf8"),
  275. 'text': d['text'].item().decode("utf8")
  276. })
  277. assert len(buffer) == 3
  278. # test
  279. buffer = []
  280. data = ds.CLUEDataset(TEST_FILE, task='WSC', usage='test')
  281. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  282. buffer.append({
  283. 'span1_index': d['span1_index'],
  284. 'span2_index': d['span2_index'],
  285. 'span1_text': d['span1_text'].item().decode("utf8"),
  286. 'span2_text': d['span2_text'].item().decode("utf8"),
  287. 'idx': d['idx'],
  288. 'text': d['text'].item().decode("utf8")
  289. })
  290. assert len(buffer) == 3
  291. # eval
  292. buffer = []
  293. data = ds.CLUEDataset(EVAL_FILE, task='WSC', usage='eval')
  294. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  295. buffer.append({
  296. 'span1_index': d['span1_index'],
  297. 'span2_index': d['span2_index'],
  298. 'span1_text': d['span1_text'].item().decode("utf8"),
  299. 'span2_text': d['span2_text'].item().decode("utf8"),
  300. 'idx': d['idx'],
  301. 'label': d['label'].item().decode("utf8"),
  302. 'text': d['text'].item().decode("utf8")
  303. })
  304. assert len(buffer) == 3
  305. def test_clue_to_device():
  306. """
  307. Test CLUE with to_device
  308. """
  309. TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
  310. data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
  311. data = data.to_device()
  312. data.send()
  313. def test_clue_invalid_files():
  314. """
  315. Test CLUE with invalid files
  316. """
  317. AFQMC_DIR = '../data/dataset/testCLUE/afqmc'
  318. afqmc_train_json = os.path.join(AFQMC_DIR)
  319. with pytest.raises(ValueError) as info:
  320. _ = ds.CLUEDataset(afqmc_train_json, task='AFQMC', usage='train', shuffle=False)
  321. assert "The following patterns did not match any files" in str(info.value)
  322. assert AFQMC_DIR in str(info.value)
  323. def test_clue_exception_file_path():
  324. """
  325. Test file info in err msg when exception occur of CLUE dataset
  326. """
  327. TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
  328. def exception_func(item):
  329. raise Exception("Error occur!")
  330. try:
  331. data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
  332. data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
  333. for _ in data.create_dict_iterator():
  334. pass
  335. assert False
  336. except RuntimeError as e:
  337. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  338. try:
  339. data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
  340. data = data.map(operations=exception_func, input_columns=["sentence1"], num_parallel_workers=1)
  341. for _ in data.create_dict_iterator():
  342. pass
  343. assert False
  344. except RuntimeError as e:
  345. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  346. try:
  347. data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
  348. data = data.map(operations=exception_func, input_columns=["sentence2"], num_parallel_workers=1)
  349. for _ in data.create_dict_iterator():
  350. pass
  351. assert False
  352. except RuntimeError as e:
  353. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  354. if __name__ == "__main__":
  355. test_clue()
  356. test_clue_num_shards()
  357. test_clue_num_samples()
  358. test_textline_dataset_get_datasetsize()
  359. test_clue_afqmc()
  360. test_clue_cmnli()
  361. test_clue_csl()
  362. test_clue_iflytek()
  363. test_clue_tnews()
  364. test_clue_wsc()
  365. test_clue_to_device()
  366. test_clue_invalid_files()
  367. test_clue_exception_file_path()