You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_matching.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import pytest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, QuoraPipe, MNLIPipe, \
  5. CNXNLIPipe, BQCorpusPipe, LCQMCPipe
  6. from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, QuoraBertPipe, MNLIBertPipe, \
  7. CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe
  8. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  9. class TestMatchingPipe:
  10. def test_process_from_file(self):
  11. for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
  12. print(pipe)
  13. data_bundle = pipe(tokenizer='raw').process_from_file()
  14. print(data_bundle)
  15. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  16. class TestMatchingBertPipe:
  17. def test_process_from_file(self):
  18. for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
  19. print(pipe)
  20. data_bundle = pipe(tokenizer='raw').process_from_file()
  21. print(data_bundle)
  22. class TestRunMatchingPipe:
  23. def test_load(self):
  24. data_set_dict = {
  25. 'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
  26. 'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
  27. 'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
  28. 'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
  29. 'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
  30. 'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
  31. 'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
  32. }
  33. for k, v in data_set_dict.items():
  34. path, pipe1, pipe2, data_set, vocab, warns = v
  35. if warns:
  36. data_bundle1 = pipe1(tokenizer='raw').process_from_file(path)
  37. data_bundle2 = pipe2(tokenizer='raw').process_from_file(path)
  38. else:
  39. data_bundle1 = pipe1(tokenizer='raw').process_from_file(path)
  40. data_bundle2 = pipe2(tokenizer='raw').process_from_file(path)
  41. assert(isinstance(data_bundle1, DataBundle))
  42. assert(len(data_set) == data_bundle1.num_dataset)
  43. print(k)
  44. print(data_bundle1)
  45. print(data_bundle2)
  46. for x, y in zip(data_set, data_bundle1.iter_datasets()):
  47. name, dataset = y
  48. assert(x == len(dataset))
  49. assert(len(data_set) == data_bundle2.num_dataset)
  50. for x, y in zip(data_set, data_bundle2.iter_datasets()):
  51. name, dataset = y
  52. assert(x == len(dataset))
  53. assert(len(vocab) == data_bundle1.num_vocab)
  54. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  55. name, vocabs = y
  56. assert(x == len(vocabs))
  57. assert(len(vocab) == data_bundle2.num_vocab)
  58. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  59. name, vocabs = y
  60. assert(x + 1 if name == 'words' else x == len(vocabs))
  61. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  62. def test_spacy(self):
  63. data_set_dict = {
  64. 'Quora': ('tests/data_for_tests/io/Quora', QuoraPipe, QuoraBertPipe, (2, 2, 2), (93, 2)),
  65. }
  66. for k, v in data_set_dict.items():
  67. path, pipe1, pipe2, data_set, vocab = v
  68. data_bundle1 = pipe1(tokenizer='spacy').process_from_file(path)
  69. data_bundle2 = pipe2(tokenizer='spacy').process_from_file(path)
  70. assert(isinstance(data_bundle1, DataBundle))
  71. assert(len(data_set) == data_bundle1.num_dataset)
  72. print(k)
  73. print(data_bundle1)
  74. print(data_bundle2)
  75. for x, y in zip(data_set, data_bundle1.iter_datasets()):
  76. name, dataset = y
  77. assert(x == len(dataset))
  78. assert(len(data_set) == data_bundle2.num_dataset)
  79. for x, y in zip(data_set, data_bundle2.iter_datasets()):
  80. name, dataset = y
  81. assert(x == len(dataset))
  82. assert(len(vocab) == data_bundle1.num_vocab)
  83. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  84. name, vocabs = y
  85. assert(x == len(vocabs))
  86. assert(len(vocab) == data_bundle2.num_vocab)
  87. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  88. name, vocabs = y
  89. assert(x + 1 if name == 'words' else x == len(vocabs))