You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_tester.py 1.2 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. from fastNLP.core.preprocess import SeqLabelPreprocess
  2. from fastNLP.core.tester import SeqLabelTester
  3. from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
  4. from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
  5. from fastNLP.models.sequence_modeling import SeqLabeling
  6. data_name = "pku_training.utf8"
  7. pickle_path = "data_for_tests"
  8. def foo():
  9. loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8")
  10. train_data = loader.load_pku()
  11. train_args = ConfigSection()
  12. ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})
  13. # Preprocessor
  14. p = SeqLabelPreprocess()
  15. train_data = p.run(train_data)
  16. train_args["vocab_size"] = p.vocab_size
  17. train_args["num_classes"] = p.num_classes
  18. model = SeqLabeling(train_args)
  19. valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
  20. "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
  21. "use_cuda": True}
  22. validator = SeqLabelTester(**valid_args)
  23. print("start validation.")
  24. validator.test(model, train_data)
  25. print(validator.show_metrics())
  26. if __name__ == "__main__":
  27. foo()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等。