You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_classification_loader.py 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import unittest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.loader.classification import YelpFullLoader, YelpPolarityLoader, IMDBLoader, \
  5. SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
  6. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  7. class TestDownload(unittest.TestCase):
  8. def test_download(self):
  9. for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
  10. loader().download()
  11. def test_load(self):
  12. for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
  13. data_bundle = loader().load()
  14. print(data_bundle)
  15. class TestLoad(unittest.TestCase):
  16. def test_process_from_file(self):
  17. data_set_dict = {
  18. 'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityLoader, (6, 6, 6), False),
  19. 'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullLoader, (6, 6, 6), False),
  20. 'sst-2': ('test/data_for_tests/io/SST-2', SST2Loader, (5, 5, 5), True),
  21. 'sst': ('test/data_for_tests/io/SST', SSTLoader, (6, 6, 6), False),
  22. 'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
  23. 'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
  24. 'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
  25. 'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False),
  26. }
  27. for k, v in data_set_dict.items():
  28. path, loader, data_set, warns = v
  29. with self.subTest(path=path):
  30. if warns:
  31. with self.assertWarns(Warning):
  32. data_bundle = loader().load(path)
  33. else:
  34. data_bundle = loader().load(path)
  35. self.assertTrue(isinstance(data_bundle, DataBundle))
  36. self.assertEqual(len(data_set), data_bundle.num_dataset)
  37. for x, y in zip(data_set, data_bundle.iter_datasets()):
  38. name, dataset = y
  39. with self.subTest(split=name):
  40. self.assertEqual(x, len(dataset))