You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

seq_labeling.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import os
  2. import sys
  3. sys.path.append("..")
  4. import argparse
  5. from fastNLP.io.config_loader import ConfigLoader, ConfigSection
  6. from fastNLP.io.dataset_loader import BaseLoader
  7. from fastNLP.io.model_saver import ModelSaver
  8. from fastNLP.io.model_loader import ModelLoader
  9. from fastNLP.core.tester import SeqLabelTester
  10. from fastNLP.models.sequence_modeling import SeqLabeling
  11. from fastNLP.core.predictor import SeqLabelInfer
  12. from fastNLP.core.optimizer import Optimizer
  13. from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
  14. from fastNLP.core.metrics import SeqLabelEvaluator
  15. from fastNLP.core.utils import save_pickle, load_pickle
  16. parser = argparse.ArgumentParser()
  17. parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")
  18. parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt",
  19. help="path to the training data")
  20. parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
  21. parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model")
  22. parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt",
  23. help="data used for inference")
  24. args = parser.parse_args()
  25. pickle_path = args.save
  26. model_name = args.model_name
  27. config_dir = args.config
  28. data_path = args.train
  29. data_infer_path = args.infer
  30. def infer():
  31. # Load infer configuration, the same as test
  32. test_args = ConfigSection()
  33. ConfigLoader().load_config(config_dir, {"POS_infer": test_args})
  34. # fetch dictionary size and number of labels from pickle files
  35. word_vocab = load_pickle(pickle_path, "word2id.pkl")
  36. label_vocab = load_pickle(pickle_path, "label2id.pkl")
  37. test_args["vocab_size"] = len(word_vocab)
  38. test_args["num_classes"] = len(label_vocab)
  39. print("vocabularies loaded")
  40. # Define the same model
  41. model = SeqLabeling(test_args)
  42. print("model defined")
  43. # Dump trained parameters into the model
  44. ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
  45. print("model loaded!")
  46. # Data Loader
  47. infer_data = SeqLabelDataSet(load_func=BaseLoader.load)
  48. infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True)
  49. print("data set prepared")
  50. # Inference interface
  51. infer = SeqLabelInfer(pickle_path)
  52. results = infer.predict(model, infer_data)
  53. for res in results:
  54. print(res)
  55. print("Inference finished!")
  56. def train_and_test():
  57. # Config Loader
  58. trainer_args = ConfigSection()
  59. model_args = ConfigSection()
  60. ConfigLoader().load_config(config_dir, {
  61. "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})
  62. data_set = SeqLabelDataSet()
  63. data_set.load(data_path)
  64. train_set, dev_set = data_set.split(0.3, shuffle=True)
  65. model_args["vocab_size"] = len(data_set.word_vocab)
  66. model_args["num_classes"] = len(data_set.label_vocab)
  67. save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
  68. save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")
  69. """
  70. trainer = SeqLabelTrainer(
  71. epochs=trainer_args["epochs"],
  72. batch_size=trainer_args["batch_size"],
  73. validate=False,
  74. use_cuda=trainer_args["use_cuda"],
  75. pickle_path=pickle_path,
  76. save_best_dev=trainer_args["save_best_dev"],
  77. model_name=model_name,
  78. optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
  79. )
  80. """
  81. # Model
  82. model = SeqLabeling(model_args)
  83. model.fit(train_set, dev_set,
  84. epochs=trainer_args["epochs"],
  85. batch_size=trainer_args["batch_size"],
  86. validate=False,
  87. use_cuda=trainer_args["use_cuda"],
  88. pickle_path=pickle_path,
  89. save_best_dev=trainer_args["save_best_dev"],
  90. model_name=model_name,
  91. optimizer=Optimizer("SGD", lr=0.01, momentum=0.9))
  92. # Start training
  93. # trainer.train(model, train_set, dev_set)
  94. print("Training finished!")
  95. # Saver
  96. saver = ModelSaver(os.path.join(pickle_path, model_name))
  97. saver.save_pytorch(model)
  98. print("Model saved!")
  99. del model
  100. change_field_is_target(dev_set, "truth", True)
  101. # Define the same model
  102. model = SeqLabeling(model_args)
  103. # Dump trained parameters into the model
  104. ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
  105. print("model loaded!")
  106. # Load test configuration
  107. tester_args = ConfigSection()
  108. ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})
  109. # Tester
  110. tester = SeqLabelTester(batch_size=4,
  111. use_cuda=False,
  112. pickle_path=pickle_path,
  113. model_name="seq_label_in_test.pkl",
  114. evaluator=SeqLabelEvaluator()
  115. )
  116. # Start testing with validation data
  117. tester.test(model, dev_set)
  118. print("model tested!")
  119. if __name__ == "__main__":
  120. train_and_test()
  121. infer()