You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_POS_pipeline.py 1.1 kB

123456789101112131415161718192021222324252627282930313233343536373839
  # Script that drives the POS training pipeline: load the config, load the
  # data, preprocess it, build a trainer and a model, then train.
  import sys

  # Add the parent of the current working directory to the import search path.
  # NOTE(review): presumably this is what lets the ``fastNLP`` imports below
  # resolve when the script is launched from the tests directory; confirm.
  sys.path.append("..")

  from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  from fastNLP.action.trainer import POSTrainer
  from fastNLP.loader.dataset_loader import POSDatasetLoader
  from fastNLP.loader.preprocess import POSPreprocess
  from fastNLP.models.sequence_modeling import SeqLabeling

  # All three locations are relative paths, so they are resolved against the
  # directory the script is started from.
  data_name = "people.txt"
  data_path = "data_for_tests/people.txt"
  pickle_path = "data_for_tests"

  if __name__ == "__main__":
      # Configuration: hand an empty section to the loader under the "POS" key.
      # NOTE(review): presumably load_config fills the section in place from
      # the given config file; verify against ConfigLoader.
      train_args = ConfigSection()
      config_loader = ConfigLoader("config.cfg", "")
      config_loader.load_config("./data_for_tests/config", {"POS": train_args})

      # Data loading.
      dataset_loader = POSDatasetLoader(data_name, data_path)
      train_data = dataset_loader.load_lines()

      # Preprocessing: the preprocessor exposes the two sizes that both the
      # trainer arguments and the model constructor need.
      preprocessor = POSPreprocess(train_data, pickle_path)
      vocab_size = preprocessor.vocab_size
      num_classes = preprocessor.num_classes

      # Copy the data-dependent sizes into the training arguments before the
      # trainer is constructed from them.
      train_args["vocab_size"] = vocab_size
      train_args["num_classes"] = num_classes
      trainer = POSTrainer(train_args)

      # Model.
      # NOTE(review): the meaning of the leading positional values 100 and 1
      # is not visible here; check the SeqLabeling signature.
      model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)

      # Training.
      trainer.train(model)
      print("Training finished!")

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等