You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_classification.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import unittest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe, \
  5. AGsNewsPipe, DBPediaPipe
  6. from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe
  7. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  8. class TestClassificationPipe(unittest.TestCase):
  9. def test_process_from_file(self):
  10. for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
  11. with self.subTest(pipe=pipe):
  12. print(pipe)
  13. data_bundle = pipe(tokenizer='raw').process_from_file()
  14. print(data_bundle)
  15. class TestRunPipe(unittest.TestCase):
  16. def test_load(self):
  17. for pipe in [IMDBPipe]:
  18. data_bundle = pipe(tokenizer='raw').process_from_file('test/data_for_tests/io/imdb')
  19. print(data_bundle)
  20. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  21. class TestCNClassificationPipe(unittest.TestCase):
  22. def test_process_from_file(self):
  23. for pipe in [ChnSentiCorpPipe]:
  24. with self.subTest(pipe=pipe):
  25. data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
  26. print(data_bundle)
  27. class TestRunClassificationPipe(unittest.TestCase):
  28. def test_process_from_file(self):
  29. data_set_dict = {
  30. 'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
  31. {'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
  32. False),
  33. 'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe,
  34. {'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
  35. False),
  36. 'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe,
  37. {'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
  38. True),
  39. 'sst': ('test/data_for_tests/io/SST', SSTPipe,
  40. {'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
  41. False),
  42. 'imdb': ('test/data_for_tests/io/imdb', IMDBPipe,
  43. {'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
  44. False),
  45. 'ag': ('test/data_for_tests/io/ag', AGsNewsPipe,
  46. {'train': 4, 'test': 5}, {'words': 257, 'target': 4},
  47. False),
  48. 'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe,
  49. {'train': 14, 'test': 5}, {'words': 496, 'target': 14},
  50. False),
  51. 'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
  52. {'train': 6, 'dev': 6, 'test': 6},
  53. {'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
  54. False),
  55. 'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe,
  56. {'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
  57. False),
  58. 'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
  59. {'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
  60. False),
  61. }
  62. for k, v in data_set_dict.items():
  63. path, pipe, data_set, vocab, warns = v
  64. with self.subTest(path=path):
  65. if 'Chn' not in k:
  66. if warns:
  67. with self.assertWarns(Warning):
  68. data_bundle = pipe(tokenizer='raw').process_from_file(path)
  69. else:
  70. data_bundle = pipe(tokenizer='raw').process_from_file(path)
  71. else:
  72. data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path)
  73. self.assertTrue(isinstance(data_bundle, DataBundle))
  74. self.assertEqual(len(data_set), data_bundle.num_dataset)
  75. for name, dataset in data_bundle.iter_datasets():
  76. self.assertTrue(name in data_set.keys())
  77. self.assertEqual(data_set[name], len(dataset))
  78. self.assertEqual(len(vocab), data_bundle.num_vocab)
  79. for name, vocabs in data_bundle.iter_vocabs():
  80. self.assertTrue(name in vocab.keys())
  81. self.assertEqual(vocab[name], len(vocabs))