You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

test_matching_loader.py 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import unittest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \
  5. BQCorpusLoader, XNLILoader, LCQMCLoader
  6. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  7. class TestMatchingDownload(unittest.TestCase):
  8. def test_download(self):
  9. for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
  10. loader().download()
  11. with self.assertRaises(Exception):
  12. QuoraLoader().load()
  13. def test_load(self):
  14. for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
  15. data_bundle = loader().load()
  16. print(data_bundle)
  17. class TestMatchingLoad(unittest.TestCase):
  18. def test_load(self):
  19. data_set_dict = {
  20. 'RTE': ('test/data_for_tests/io/RTE', RTELoader, (5, 5, 5), True),
  21. 'SNLI': ('test/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False),
  22. 'QNLI': ('test/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True),
  23. 'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
  24. 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
  25. 'XNLI': ('test/data_for_tests/io/XNLI', XNLILoader, (6, 7, 6), False),
  26. 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False),
  27. }
  28. for k, v in data_set_dict.items():
  29. path, loader, instance, warns = v
  30. if warns:
  31. with self.assertWarns(Warning):
  32. data_bundle = loader().load(path)
  33. else:
  34. data_bundle = loader().load(path)
  35. self.assertTrue(isinstance(data_bundle, DataBundle))
  36. self.assertEqual(len(instance), data_bundle.num_dataset)
  37. for x, y in zip(instance, data_bundle.iter_datasets()):
  38. name, dataset = y
  39. self.assertEqual(x, len(dataset))