You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run.py 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import os
  2. import sys
  3. sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
  4. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  5. from fastNLP.core.trainer import SeqLabelTrainer
  6. from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
  7. from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
  8. from fastNLP.saver.model_saver import ModelSaver
  9. from fastNLP.loader.model_loader import ModelLoader
  10. from fastNLP.core.tester import SeqLabelTester
  11. from fastNLP.models.sequence_modeling import AdvSeqLabel
  12. from fastNLP.core.predictor import SeqLabelInfer
  13. # not in the file's dir
  14. if len(os.path.dirname(__file__)) != 0:
  15. os.chdir(os.path.dirname(__file__))
  16. datadir = "/home/zyfeng/data/"
  17. cfgfile = './cws.cfg'
  18. data_name = "pku_training.utf8"
  19. cws_data_path = os.path.join(datadir, "pku_training.utf8")
  20. pickle_path = "save"
  21. data_infer_path = os.path.join(datadir, "infer.utf8")
  22. def infer():
  23. # Config Loader
  24. test_args = ConfigSection()
  25. ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})
  26. # fetch dictionary size and number of labels from pickle files
  27. word2index = load_pickle(pickle_path, "word2id.pkl")
  28. test_args["vocab_size"] = len(word2index)
  29. index2label = load_pickle(pickle_path, "id2class.pkl")
  30. test_args["num_classes"] = len(index2label)
  31. # Define the same model
  32. model = AdvSeqLabel(test_args)
  33. try:
  34. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  35. print('model loaded!')
  36. except Exception as e:
  37. print('cannot load model!')
  38. raise
  39. # Data Loader
  40. raw_data_loader = BaseLoader(data_name, data_infer_path)
  41. infer_data = raw_data_loader.load_lines()
  42. print('data loaded')
  43. # Inference interface
  44. infer = SeqLabelInfer(pickle_path)
  45. results = infer.predict(model, infer_data)
  46. print(results)
  47. print("Inference finished!")
  48. def train():
  49. # Config Loader
  50. train_args = ConfigSection()
  51. test_args = ConfigSection()
  52. ConfigLoader("good_name", "good_path").load_config(cfgfile, {"train": train_args, "test": test_args})
  53. # Data Loader
  54. loader = TokenizeDatasetLoader(data_name, cws_data_path)
  55. train_data = loader.load_pku()
  56. # Preprocessor
  57. preprocessor = SeqLabelPreprocess()
  58. data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3)
  59. train_args["vocab_size"] = preprocessor.vocab_size
  60. train_args["num_classes"] = preprocessor.num_classes
  61. # Trainer
  62. trainer = SeqLabelTrainer(**train_args.data)
  63. # Model
  64. model = AdvSeqLabel(train_args)
  65. try:
  66. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  67. print('model parameter loaded!')
  68. except Exception as e:
  69. print("No saved model. Continue.")
  70. pass
  71. # Start training
  72. trainer.train(model, data_train, data_dev)
  73. print("Training finished!")
  74. # Saver
  75. saver = ModelSaver("./save/saved_model.pkl")
  76. saver.save_pytorch(model)
  77. print("Model saved!")
  78. def test():
  79. # Config Loader
  80. test_args = ConfigSection()
  81. ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})
  82. # fetch dictionary size and number of labels from pickle files
  83. word2index = load_pickle(pickle_path, "word2id.pkl")
  84. test_args["vocab_size"] = len(word2index)
  85. index2label = load_pickle(pickle_path, "id2class.pkl")
  86. test_args["num_classes"] = len(index2label)
  87. # load dev data
  88. dev_data = load_pickle(pickle_path, "data_dev.pkl")
  89. # Define the same model
  90. model = AdvSeqLabel(test_args)
  91. # Dump trained parameters into the model
  92. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  93. print("model loaded!")
  94. # Tester
  95. tester = SeqLabelTester(**test_args.data)
  96. # Start testing
  97. tester.test(model, dev_data)
  98. # print test results
  99. print(tester.show_matrices())
  100. print("model tested!")
  101. if __name__ == "__main__":
  102. import argparse
  103. parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
  104. parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
  105. args = parser.parse_args()
  106. if args.mode == 'train':
  107. train()
  108. elif args.mode == 'test':
  109. test()
  110. elif args.mode == 'infer':
  111. infer()
  112. else:
  113. print('no mode specified for model!')
  114. parser.print_help()

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。