You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataLoader.py 1.4 kB

123456789101112131415161718192021222324252627282930313233343536
  1. import unittest
  2. import sys
  3. sys.path.append('..')
  4. from data.dataloader import SummarizationLoader
  5. vocab_size = 100000
  6. vocab_path = "testdata/vocab"
  7. sent_max_len = 100
  8. doc_max_timesteps = 50
  9. class TestSummarizationLoader(unittest.TestCase):
  10. def test_case1(self):
  11. sum_loader = SummarizationLoader()
  12. paths = {"train":"testdata/train.jsonl", "valid":"testdata/val.jsonl", "test":"testdata/test.jsonl"}
  13. data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps)
  14. print(data.datasets)
  15. def test_case2(self):
  16. sum_loader = SummarizationLoader()
  17. paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"}
  18. data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, domain=True)
  19. print(data.datasets, data.vocabs)
  20. def test_case3(self):
  21. sum_loader = SummarizationLoader()
  22. paths = {"train": "testdata/train.jsonl", "valid": "testdata/val.jsonl", "test": "testdata/test.jsonl"}
  23. data = sum_loader.process(paths=paths, vocab_size=vocab_size, vocab_path=vocab_path, sent_max_len=sent_max_len, doc_max_timesteps=doc_max_timesteps, tag=True)
  24. print(data.datasets, data.vocabs)