You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_classification.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import unittest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe, \
  5. AGsNewsPipe, DBPediaPipe
  6. from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe
  7. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  8. class TestClassificationPipe(unittest.TestCase):
  9. def test_process_from_file(self):
  10. for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
  11. with self.subTest(pipe=pipe):
  12. print(pipe)
  13. data_bundle = pipe(tokenizer='raw').process_from_file()
  14. print(data_bundle)
  15. class TestRunPipe(unittest.TestCase):
  16. def test_load(self):
  17. for pipe in [IMDBPipe]:
  18. data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb')
  19. print(data_bundle)
  20. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  21. class TestCNClassificationPipe(unittest.TestCase):
  22. def test_process_from_file(self):
  23. for pipe in [ChnSentiCorpPipe]:
  24. with self.subTest(pipe=pipe):
  25. data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
  26. print(data_bundle)
  27. class TestRunClassificationPipe(unittest.TestCase):
  28. def test_process_from_file(self):
  29. data_set_dict = {
  30. 'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, (6, 6, 6), (1176, 2), False),
  31. 'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1166, 5), False),
  32. 'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, (5, 5, 5), (139, 2), True),
  33. 'sst': ('test/data_for_tests/io/SST', SSTPipe, (6, 354, 6), (232, 5), False),
  34. 'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, (6, 6, 6), (1670, 2), False),
  35. 'ag': ('test/data_for_tests/io/ag', AGsNewsPipe, (5, 4), (257, 4), False),
  36. 'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe, (5, 14), (496, 14), False),
  37. 'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, (6, 6, 6), (529, 1296, 1483, 2), False),
  38. 'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe, (9, 9, 9), (1864, 9), False),
  39. 'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, (7, 6, 6), (452, 2), False),
  40. }
  41. for k, v in data_set_dict.items():
  42. path, pipe, data_set, vocab, warns = v
  43. with self.subTest(pipe=pipe):
  44. if 'Chn' not in k:
  45. if warns:
  46. with self.assertWarns(Warning):
  47. data_bundle = pipe(tokenizer='raw').process_from_file(path)
  48. else:
  49. data_bundle = pipe(tokenizer='raw').process_from_file(path)
  50. else:
  51. data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path)
  52. self.assertTrue(isinstance(data_bundle, DataBundle))
  53. self.assertEqual(len(data_set), data_bundle.num_dataset)
  54. for x, y in zip(data_set, data_bundle.iter_datasets()):
  55. name, dataset = y
  56. self.assertEqual(x, len(dataset))
  57. self.assertEqual(len(vocab), data_bundle.num_vocab)
  58. for x, y in zip(vocab, data_bundle.iter_vocabs()):
  59. name, vocabs = y
  60. self.assertEqual(x, len(vocabs))