You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

seq_labeling.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os
  2. import sys
  3. sys.path.append("..")
  4. import argparse
  5. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  6. from fastNLP.core.trainer import SeqLabelTrainer
  7. from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
  8. from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
  9. from fastNLP.saver.model_saver import ModelSaver
  10. from fastNLP.loader.model_loader import ModelLoader
  11. from fastNLP.core.tester import SeqLabelTester
  12. from fastNLP.models.sequence_modeling import SeqLabeling
  13. from fastNLP.core.predictor import SeqLabelInfer
  14. from fastNLP.core.optimizer import Optimizer
  15. parser = argparse.ArgumentParser()
  16. parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
  17. parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt",
  18. help="path to the training data")
  19. parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
  20. parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model")
  21. parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt",
  22. help="data used for inference")
  23. args = parser.parse_args()
  24. pickle_path = args.save
  25. model_name = args.model_name
  26. config_dir = args.config
  27. data_path = args.train
  28. data_infer_path = args.infer
  29. def infer():
  30. # Load infer configuration, the same as test
  31. test_args = ConfigSection()
  32. ConfigLoader("config.cfg", "").load_config(config_dir, {"POS_infer": test_args})
  33. # fetch dictionary size and number of labels from pickle files
  34. word2index = load_pickle(pickle_path, "word2id.pkl")
  35. test_args["vocab_size"] = len(word2index)
  36. index2label = load_pickle(pickle_path, "id2class.pkl")
  37. test_args["num_classes"] = len(index2label)
  38. # Define the same model
  39. model = SeqLabeling(test_args)
  40. # Dump trained parameters into the model
  41. ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
  42. print("model loaded!")
  43. # Data Loader
  44. raw_data_loader = BaseLoader("xxx", data_infer_path)
  45. infer_data = raw_data_loader.load_lines()
  46. # Inference interface
  47. infer = SeqLabelInfer(pickle_path)
  48. results = infer.predict(model, infer_data)
  49. for res in results:
  50. print(res)
  51. print("Inference finished!")
  52. def train_and_test():
  53. # Config Loader
  54. trainer_args = ConfigSection()
  55. model_args = ConfigSection()
  56. ConfigLoader("config.cfg", "").load_config(config_dir, {
  57. "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})
  58. # Data Loader
  59. pos_loader = POSDatasetLoader("xxx", data_path)
  60. train_data = pos_loader.load_lines()
  61. # Preprocessor
  62. p = SeqLabelPreprocess()
  63. data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
  64. model_args["vocab_size"] = p.vocab_size
  65. model_args["num_classes"] = p.num_classes
  66. # Trainer: two definition styles
  67. # 1
  68. # trainer = SeqLabelTrainer(trainer_args.data)
  69. # 2
  70. trainer = SeqLabelTrainer(
  71. epochs=trainer_args["epochs"],
  72. batch_size=trainer_args["batch_size"],
  73. validate=trainer_args["validate"],
  74. use_cuda=trainer_args["use_cuda"],
  75. pickle_path=pickle_path,
  76. save_best_dev=trainer_args["save_best_dev"],
  77. model_name=model_name,
  78. optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
  79. )
  80. # Model
  81. model = SeqLabeling(model_args)
  82. # Start training
  83. trainer.train(model, data_train, data_dev)
  84. print("Training finished!")
  85. # Saver
  86. saver = ModelSaver(os.path.join(pickle_path, model_name))
  87. saver.save_pytorch(model)
  88. print("Model saved!")
  89. del model, trainer, pos_loader
  90. # Define the same model
  91. model = SeqLabeling(model_args)
  92. # Dump trained parameters into the model
  93. ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
  94. print("model loaded!")
  95. # Load test configuration
  96. tester_args = ConfigSection()
  97. ConfigLoader("config.cfg", "").load_config(config_dir, {"test_seq_label_tester": tester_args})
  98. # Tester
  99. tester = SeqLabelTester(save_output=False,
  100. save_loss=False,
  101. save_best_dev=False,
  102. batch_size=8,
  103. use_cuda=False,
  104. pickle_path=pickle_path,
  105. model_name="seq_label_in_test.pkl",
  106. print_every_step=1
  107. )
  108. # Start testing with validation data
  109. tester.test(model, data_dev)
  110. # print test results
  111. print(tester.show_matrices())
  112. print("model tested!")
  113. if __name__ == "__main__":
  114. train_and_test()
  115. infer()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等