You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cws_train.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import sys
  2. sys.path.append("..")
  3. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  4. from fastNLP.core.trainer import SeqLabelTrainer
  5. from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
  6. from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
  7. from fastNLP.saver.model_saver import ModelSaver
  8. from fastNLP.loader.model_loader import ModelLoader
  9. from fastNLP.core.tester import SeqLabelTester
  10. from fastNLP.models.sequence_modeling import SeqLabeling
  11. from fastNLP.core.predictor import Predictor
  # Name passed as the first argument to the dataset loaders below.
  data_name = "pku_training.utf8"
  # Training corpus read by train_test().
  # NOTE(review): absolute, user-specific path -- adjust for your machine.
  cws_data_path = "/home/zyfeng/data/pku_training.utf8"
  # Directory handed to the preprocessor, load_pickle() and Predictor.
  pickle_path = "./save/"
  # Raw text file read line by line in infer().
  # NOTE(review): absolute, user-specific path -- adjust for your machine.
  data_infer_path = "/home/zyfeng/data/pku_test.utf8"
  def infer():
      """Load a saved SeqLabeling model and print its predictions.

      Steps, in order:
        1. Read the "POS_test" section of ./data_for_tests/config.
        2. Set "vocab_size" / "num_classes" from the lengths of the pickled
           word2id.pkl and id2class.pkl objects found under ``pickle_path``.
        3. Build the model and load weights from
           ./data_for_tests/saved_model.pkl.
        4. Read ``data_infer_path`` with BaseLoader.load_lines() and pass the
           result to Predictor.predict().

      Takes no arguments and returns None; results are printed to stdout.
      """
      # Inference reuses the test-time configuration section.
      config = ConfigSection()
      ConfigLoader("config.cfg", "").load_config(
          "./data_for_tests/config", {"POS_test": config}
      )

      # Vocabulary size and label count come from the pickled dictionaries.
      config["vocab_size"] = len(load_pickle(pickle_path, "word2id.pkl"))
      config["num_classes"] = len(load_pickle(pickle_path, "id2class.pkl"))

      # Rebuild the architecture, then restore the trained parameters.
      model = SeqLabeling(config)
      ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
      print("model loaded!")

      # Raw input lines to run the model on.
      infer_data = BaseLoader(data_name, data_infer_path).load_lines()

      # Named `predictor` so the local does not shadow this function's name.
      predictor = Predictor(pickle_path)
      print(predictor.predict(model, infer_data))
      print("Inference finished!")
  def train_test():
      """Train a SeqLabeling model on the PKU data, save it, then call test().

      Steps, in order:
        1. Read the "train" and "test" sections of ./cws.cfg.
        2. Load ``cws_data_path`` with TokenizeDatasetLoader.load_pku().
        3. Preprocess with SeqLabelPreprocess.run(), using ``pickle_path`` and
           train_dev_split=0.3, which yields a train set and a dev set.
        4. Copy the preprocessor's vocab_size / num_classes into the training
           config, build the trainer and the model, and train.
        5. Save the model to ./save/saved_model.pkl.
        6. Call test() on the dev set.

      Takes no arguments and returns None.

      NOTE(review): test() loads weights from
      ./data_for_tests/saved_model.pkl, not from the file saved here --
      confirm that this is intended.
      """
      # One ConfigSection per section name; "train" is created first, as before.
      sections = {"train": ConfigSection(), "test": ConfigSection()}
      ConfigLoader("good_name", "good_path").load_config("./cws.cfg", sections)
      train_args = sections["train"]

      # Load the corpus.
      corpus = TokenizeDatasetLoader(data_name, cws_data_path).load_pku()

      # Preprocess and split into train / dev portions.
      preprocessor = SeqLabelPreprocess()
      data_train, data_dev = preprocessor.run(
          corpus, pickle_path=pickle_path, train_dev_split=0.3
      )
      train_args["vocab_size"] = preprocessor.vocab_size
      train_args["num_classes"] = preprocessor.num_classes

      # The trainer is constructed before the model, matching the original order.
      trainer = SeqLabelTrainer(train_args)
      model = SeqLabeling(train_args)
      trainer.train(model, data_train, data_dev)
      print("Training finished!")

      # Persist the trained parameters.
      ModelSaver("./save/saved_model.pkl").save_pytorch(model)
      print("Model saved!")

      # Evaluate on the held-out dev split.
      test(data_dev)
  def test(test_data):
      """Load a saved SeqLabeling model and evaluate it on ``test_data``.

      The model is built from the "POS" section of ./data_for_tests/config and
      its weights are loaded from ./data_for_tests/saved_model.pkl. The tester
      is configured from the "POS_test" section of the same config file.

      Args:
          test_data: Dataset forwarded unchanged to SeqLabelTester.test().

      Returns None; the output of tester.show_matrices() is printed to stdout.
      """
      config_file = "./data_for_tests/config"

      # Model hyper-parameters come from the "POS" section.
      model_args = ConfigSection()
      ConfigLoader("config.cfg", "").load_config(config_file, {"POS": model_args})

      # Rebuild the architecture, then restore the trained parameters.
      model = SeqLabeling(model_args)
      ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
      print("model loaded!")

      # Tester settings come from the "POS_test" section.
      tester_args = ConfigSection()
      ConfigLoader("config.cfg", "").load_config(
          config_file, {"POS_test": tester_args}
      )

      # Run the evaluation and report.
      tester = SeqLabelTester(tester_args)
      tester.test(model, test_data)
      print(tester.show_matrices())
      print("model tested!")
  if __name__ == "__main__":
      # Script entry point: train, save, then evaluate on the dev split.
      train_test()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等