You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cws.py 2.2 kB

Lines 1–52
  1. import pytest
  2. import os
  3. from fastNLP.io.pipe.cws import CWSPipe
  class TestCWSPipe:
      """Tests for ``CWSPipe`` on the named corpora and on an in-memory DataBundle."""

      @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
      def test_process_from_file(self):
          """Build a pipe per corpus name, call ``process_from_file()`` and print the bundle.

          Skipped unless ``download`` is present in ``os.environ``, because
          ``process_from_file()`` is called here without a local path.
          """
          for name in ('pku', 'cityu', 'as', 'msra'):
              pipe = CWSPipe(dataset_name=name)
              bundle = pipe.process_from_file()
              print(bundle)

      def test_demo(self):
          """Process a one-sentence 'train' split and check '<' is absent from the 'chars' vocab.

          Related to issue https://github.com/fastnlp/fastNLP/issues/324#issue-705081091
          """
          from fastNLP import DataSet, Instance
          from fastNLP.io import DataBundle

          bundle = DataBundle()
          dataset = DataSet()
          dataset.append(Instance(raw_words="截流 进入 最后 冲刺 ( 附 图片 1 张 )"))
          bundle.set_dataset(dataset, name='train')

          processed = CWSPipe().process(bundle)
          chars_vocab = processed.get_vocab('chars')
          assert '<' not in chars_vocab
  class TestRunCWSPipe:
      """Run ``CWSPipe`` on the small bundled corpora, with ``num_proc=0`` and ``num_proc=2``."""

      # Relative paths: presumably the suite is run from the repository root -- confirm.
      _DATA_DIR = 'tests/data_for_tests/io'
      _DATASET_NAMES = ('msra', 'cityu', 'as', 'pku')
      # Tokens that must be present in the 'chars' vocabulary (i.e. not mapped
      # to the unknown index) when the pipe runs with replace_num_alpha=True.
      _REPLACED_TOKENS = ('<', '>', '<NUM>')

      def _check_process_from_file(self, num_proc):
          """Process every bundled corpus with bigrams and trigrams enabled and print it."""
          for dataset_name in self._DATASET_NAMES:
              pipe = CWSPipe(bigrams=True, trigrams=True, num_proc=num_proc)
              data_bundle = pipe.process_from_file(f'{self._DATA_DIR}/cws_{dataset_name}')
              print(data_bundle)

      def _check_replace_number(self, num_proc):
          """Process the pku corpus with replace_num_alpha=True and check the replacement tokens.

          Each token in ``_REPLACED_TOKENS`` must resolve to a real index in the
          'chars' vocabulary rather than to the vocabulary's unknown index.
          """
          pipe = CWSPipe(bigrams=True, replace_num_alpha=True, num_proc=num_proc)
          data_bundle = pipe.process_from_file(f'{self._DATA_DIR}/cws_pku')
          vocab = data_bundle.get_vocab('chars')
          for word in self._REPLACED_TOKENS:
              # Compare against the vocabulary's own unknown index instead of the
              # previously hard-coded 1 (presumably the default unknown index).
              assert vocab.to_index(word) != vocab.unknown_idx, word

      def test_process_from_file(self):
          """Process all bundled corpora without worker processes."""
          self._check_process_from_file(num_proc=0)

      def test_replace_number(self):
          """Check number/alpha replacement tokens without worker processes."""
          self._check_replace_number(num_proc=0)

      def test_process_from_file_proc(self):
          """Process all bundled corpora with ``num_proc=2``."""
          self._check_process_from_file(num_proc=2)

      def test_replace_number_proc(self):
          """Check number/alpha replacement tokens with ``num_proc=2``."""
          self._check_replace_number(num_proc=2)