You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tester.py 1.3 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. from fastNLP.core.preprocess import SeqLabelPreprocess
  2. from fastNLP.core.tester import SeqLabelTester
  3. from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
  4. from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
  5. from fastNLP.models.sequence_modeling import SeqLabeling
  6. data_name = "pku_training.utf8"
  7. cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8"
  8. pickle_path = "data_for_tests"
  9. def foo():
  10. loader = TokenizeDatasetLoader(data_name, "./data_for_tests/cws_pku_utf_8")
  11. train_data = loader.load_pku()
  12. train_args = ConfigSection()
  13. ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
  14. # Preprocessor
  15. p = SeqLabelPreprocess(train_data, pickle_path)
  16. train_args["vocab_size"] = p.vocab_size
  17. train_args["num_classes"] = p.num_classes
  18. model = SeqLabeling(train_args)
  19. valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
  20. "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
  21. "use_cuda": True}
  22. validator = SeqLabelTester(valid_args)
  23. print("start validation.")
  24. validator.test(model)
  25. print(validator.show_matrices())
  26. if __name__ == "__main__":
  27. foo()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等