You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_classification.py 4.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import pytest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe, \
  5. AGsNewsPipe, DBPediaPipe
  6. from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe
  7. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  8. class TestClassificationPipe:
  9. def test_process_from_file(self):
  10. for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
  11. print(pipe)
  12. data_bundle = pipe(tokenizer='raw').process_from_file()
  13. print(data_bundle)
  14. class TestRunPipe:
  15. def test_load(self):
  16. for pipe in [IMDBPipe]:
  17. data_bundle = pipe(tokenizer='raw').process_from_file('tests/data_for_tests/io/imdb')
  18. print(data_bundle)
  19. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  20. class TestCNClassificationPipe:
  21. def test_process_from_file(self):
  22. for pipe in [ChnSentiCorpPipe]:
  23. data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
  24. print(data_bundle)
  25. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  26. class TestRunClassificationPipe:
  27. def test_process_from_file(self):
  28. data_set_dict = {
  29. 'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
  30. {'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
  31. False),
  32. 'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullPipe,
  33. {'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
  34. False),
  35. 'sst-2': ('tests/data_for_tests/io/SST-2', SST2Pipe,
  36. {'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
  37. True),
  38. 'sst': ('tests/data_for_tests/io/SST', SSTPipe,
  39. {'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
  40. False),
  41. 'imdb': ('tests/data_for_tests/io/imdb', IMDBPipe,
  42. {'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
  43. False),
  44. 'ag': ('tests/data_for_tests/io/ag', AGsNewsPipe,
  45. {'train': 4, 'test': 5}, {'words': 257, 'target': 4},
  46. False),
  47. 'dbpedia': ('tests/data_for_tests/io/dbpedia', DBPediaPipe,
  48. {'train': 14, 'test': 5}, {'words': 496, 'target': 14},
  49. False),
  50. 'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
  51. {'train': 6, 'dev': 6, 'test': 6},
  52. {'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
  53. False),
  54. 'Chn-THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsPipe,
  55. {'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
  56. False),
  57. 'Chn-WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
  58. {'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
  59. False),
  60. }
  61. for k, v in data_set_dict.items():
  62. path, pipe, data_set, vocab, warns = v
  63. if 'Chn' not in k:
  64. if warns:
  65. data_bundle = pipe(tokenizer='raw').process_from_file(path)
  66. else:
  67. data_bundle = pipe(tokenizer='raw').process_from_file(path)
  68. else:
  69. data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path)
  70. assert(isinstance(data_bundle, DataBundle))
  71. assert(len(data_set) == data_bundle.num_dataset)
  72. for name, dataset in data_bundle.iter_datasets():
  73. assert(name in data_set.keys())
  74. assert(data_set[name] == len(dataset))
  75. assert(len(vocab) == data_bundle.num_vocab)
  76. for name, vocabs in data_bundle.iter_vocabs():
  77. assert(name in vocab.keys())
  78. assert(vocab[name] == len(vocabs))