You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

run.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import os
  2. import sys
  3. sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
  4. from fastNLP.io.config_io import ConfigLoader, ConfigSection
  5. from fastNLP.core.trainer import SeqLabelTrainer
  6. from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader
  7. from fastNLP.core.utils import load_pickle
  8. from fastNLP.io.model_io import ModelLoader, ModelSaver
  9. from fastNLP.core.tester import SeqLabelTester
  10. from fastNLP.models.sequence_modeling import AdvSeqLabel
  11. from fastNLP.core.predictor import SeqLabelInfer
  12. from fastNLP.core.utils import save_pickle
  13. from fastNLP.core.metrics import SeqLabelEvaluator
  14. # not in the file's dir
  15. if len(os.path.dirname(__file__)) != 0:
  16. os.chdir(os.path.dirname(__file__))
  17. datadir = "/home/zyfeng/data/"
  18. cfgfile = './cws.cfg'
  19. cws_data_path = os.path.join(datadir, "pku_training.utf8")
  20. pickle_path = "save"
  21. data_infer_path = os.path.join(datadir, "infer.utf8")
  22. def infer():
  23. # Config Loader
  24. test_args = ConfigSection()
  25. ConfigLoader().load_config(cfgfile, {"POS_test": test_args})
  26. # fetch dictionary size and number of labels from pickle files
  27. word2index = load_pickle(pickle_path, "word2id.pkl")
  28. test_args["vocab_size"] = len(word2index)
  29. index2label = load_pickle(pickle_path, "label2id.pkl")
  30. test_args["num_classes"] = len(index2label)
  31. # Define the same model
  32. model = AdvSeqLabel(test_args)
  33. try:
  34. ModelLoader.load_pytorch(model, "./save/trained_model.pkl")
  35. print('model loaded!')
  36. except Exception as e:
  37. print('cannot load model!')
  38. raise
  39. # Data Loader
  40. infer_data = SeqLabelDataSet(load_func=BaseLoader.load_lines)
  41. infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)
  42. print('data loaded')
  43. # Inference interface
  44. infer = SeqLabelInfer(pickle_path)
  45. results = infer.predict(model, infer_data)
  46. print(results)
  47. print("Inference finished!")
  48. def train():
  49. # Config Loader
  50. train_args = ConfigSection()
  51. test_args = ConfigSection()
  52. ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args})
  53. print("loading data set...")
  54. data = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load)
  55. data.load(cws_data_path)
  56. data_train, data_dev = data.split(ratio=0.3)
  57. train_args["vocab_size"] = len(data.word_vocab)
  58. train_args["num_classes"] = len(data.label_vocab)
  59. print("vocab size={}, num_classes={}".format(len(data.word_vocab), len(data.label_vocab)))
  60. change_field_is_target(data_dev, "truth", True)
  61. save_pickle(data_dev, "./save/", "data_dev.pkl")
  62. save_pickle(data.word_vocab, "./save/", "word2id.pkl")
  63. save_pickle(data.label_vocab, "./save/", "label2id.pkl")
  64. # Trainer
  65. trainer = SeqLabelTrainer(epochs=train_args["epochs"], batch_size=train_args["batch_size"],
  66. validate=train_args["validate"],
  67. use_cuda=train_args["use_cuda"], pickle_path=train_args["pickle_path"],
  68. save_best_dev=True, print_every_step=10, model_name="trained_model.pkl",
  69. evaluator=SeqLabelEvaluator())
  70. # Model
  71. model = AdvSeqLabel(train_args)
  72. try:
  73. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  74. print('model parameter loaded!')
  75. except Exception as e:
  76. print("No saved model. Continue.")
  77. pass
  78. # Start training
  79. trainer.train(model, data_train, data_dev)
  80. print("Training finished!")
  81. # Saver
  82. saver = ModelSaver("./save/trained_model.pkl")
  83. saver.save_pytorch(model)
  84. print("Model saved!")
  85. def predict():
  86. # Config Loader
  87. test_args = ConfigSection()
  88. ConfigLoader().load_config(cfgfile, {"POS_test": test_args})
  89. # fetch dictionary size and number of labels from pickle files
  90. word2index = load_pickle(pickle_path, "word2id.pkl")
  91. test_args["vocab_size"] = len(word2index)
  92. index2label = load_pickle(pickle_path, "label2id.pkl")
  93. test_args["num_classes"] = len(index2label)
  94. # load dev data
  95. dev_data = load_pickle(pickle_path, "data_dev.pkl")
  96. # Define the same model
  97. model = AdvSeqLabel(test_args)
  98. # Dump trained parameters into the model
  99. ModelLoader.load_pytorch(model, "./save/trained_model.pkl")
  100. print("model loaded!")
  101. # Tester
  102. test_args["evaluator"] = SeqLabelEvaluator()
  103. tester = SeqLabelTester(**test_args.data)
  104. # Start testing
  105. tester.test(model, dev_data)
  106. if __name__ == "__main__":
  107. import argparse
  108. parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
  109. parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
  110. args = parser.parse_args()
  111. if args.mode == 'train':
  112. train()
  113. elif args.mode == 'test':
  114. predict()
  115. elif args.mode == 'infer':
  116. infer()
  117. else:
  118. print('no mode specified for model!')
  119. parser.print_help()