You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_matching_loader.py 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import pytest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \
  5. BQCorpusLoader, CNXNLILoader, LCQMCLoader
  6. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  7. class TestMatchingDownload:
  8. def test_download(self):
  9. for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
  10. loader().download()
  11. with pytest.raises(Exception):
  12. QuoraLoader().load()
  13. def test_load(self):
  14. for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
  15. data_bundle = loader().load()
  16. print(data_bundle)
  17. class TestMatchingLoad:
  18. def test_load(self):
  19. data_set_dict = {
  20. 'RTE': ('tests/data_for_tests/io/RTE', RTELoader, (5, 5, 5), True),
  21. 'SNLI': ('tests/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False),
  22. 'QNLI': ('tests/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True),
  23. 'MNLI': ('tests/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
  24. 'Quora': ('tests/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
  25. 'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
  26. 'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLILoader, (6, 6, 8), False),
  27. 'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCLoader, (6, 5, 6), False),
  28. }
  29. for k, v in data_set_dict.items():
  30. path, loader, instance, warns = v
  31. if warns:
  32. data_bundle = loader().load(path)
  33. else:
  34. data_bundle = loader().load(path)
  35. assert(isinstance(data_bundle, DataBundle))
  36. assert(len(instance) == data_bundle.num_dataset)
  37. for x, y in zip(instance, data_bundle.iter_datasets()):
  38. name, dataset = y
  39. assert(x == len(dataset))