You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_POS_pipeline.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import sys
  2. sys.path.append("..")
  3. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  4. from fastNLP.action.trainer import POSTrainer
  5. from fastNLP.loader.dataset_loader import POSDatasetLoader
  6. from fastNLP.loader.preprocess import POSPreprocess
  7. from fastNLP.saver.model_saver import ModelSaver
  8. from fastNLP.loader.model_loader import ModelLoader
  9. from fastNLP.action.tester import POSTester
  10. from fastNLP.models.sequence_modeling import SeqLabeling
  11. from fastNLP.action.inference import Inference
  12. data_name = "people.txt"
  13. data_path = "data_for_tests/people.txt"
  14. pickle_path = "data_for_tests"
  15. def test_infer():
  16. # Define the same model
  17. model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
  18. num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
  19. word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
  20. rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])
  21. # Dump trained parameters into the model
  22. ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model)
  23. print("model loaded!")
  24. # Data Loader
  25. pos_loader = POSDatasetLoader(data_name, data_path)
  26. infer_data = pos_loader.load_lines()
  27. # Preprocessor
  28. POSPreprocess(infer_data, pickle_path)
  29. # Inference interface
  30. infer = Inference()
  31. results = infer.predict(model, infer_data)
  32. if __name__ == "__main__":
  33. # Config Loader
  34. train_args = ConfigSection()
  35. ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
  36. # Data Loader
  37. pos_loader = POSDatasetLoader(data_name, data_path)
  38. train_data = pos_loader.load_lines()
  39. # Preprocessor
  40. p = POSPreprocess(train_data, pickle_path)
  41. train_args["vocab_size"] = p.vocab_size
  42. train_args["num_classes"] = p.num_classes
  43. # Trainer
  44. trainer = POSTrainer(train_args)
  45. # Model
  46. model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
  47. num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
  48. word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
  49. rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])
  50. # Start training
  51. trainer.train(model)
  52. print("Training finished!")
  53. # Saver
  54. saver = ModelSaver("./saved_model.pkl")
  55. saver.save_pytorch(model)
  56. print("Model saved!")
  57. del model, trainer, pos_loader
  58. # Define the same model
  59. model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
  60. num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
  61. word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
  62. rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])
  63. # Dump trained parameters into the model
  64. ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model)
  65. print("model loaded!")
  66. # Load test configuration
  67. test_args = ConfigSection()
  68. ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
  69. # Tester
  70. tester = POSTester(test_args)
  71. # Start testing
  72. tester.test(model)
  73. # print test results
  74. print(tester.show_matrices())
  75. print("model tested!")

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等