You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_classification_loader.py 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import os
  2. import pytest
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.loader.classification import YelpFullLoader, YelpPolarityLoader, IMDBLoader, \
  5. SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, \
  6. MRLoader, R8Loader, R52Loader, OhsumedLoader, NG20Loader
  7. class TestDownload:
  8. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  9. def test_download(self):
  10. for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
  11. loader().download()
  12. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  13. def test_load(self):
  14. for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
  15. data_bundle = loader().load()
  16. print(data_bundle)
  17. class TestLoad:
  18. def test_process_from_file(self):
  19. data_set_dict = {
  20. 'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityLoader, (6, 6, 6), False),
  21. 'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullLoader, (6, 6, 6), False),
  22. 'sst-2': ('tests/data_for_tests/io/SST-2', SST2Loader, (5, 5, 5), True),
  23. 'sst': ('tests/data_for_tests/io/SST', SSTLoader, (6, 6, 6), False),
  24. 'imdb': ('tests/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
  25. 'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
  26. 'THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
  27. 'WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 7, 6), False),
  28. 'mr': ('tests/data_for_tests/io/mr', MRLoader, (6, 6, 6), False),
  29. 'R8': ('tests/data_for_tests/io/R8', R8Loader, (6, 6, 6), False),
  30. 'R52': ('tests/data_for_tests/io/R52', R52Loader, (6, 6, 6), False),
  31. 'ohsumed': ('tests/data_for_tests/io/R52', OhsumedLoader, (6, 6, 6), False),
  32. '20ng': ('tests/data_for_tests/io/R52', NG20Loader, (6, 6, 6), False),
  33. }
  34. for k, v in data_set_dict.items():
  35. path, loader, data_set, warns = v
  36. data_bundle = loader().load(path)
  37. assert(isinstance(data_bundle, DataBundle))
  38. assert(len(data_set) == data_bundle.num_dataset)
  39. for x, y in zip(data_set, data_bundle.iter_datasets()):
  40. name, dataset = y
  41. assert(x == len(dataset))