You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_cws.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import os
  2. from fastNLP.core.dataset import DataSet
  3. from fastNLP.core.vocabulary import Vocabulary
  4. from fastNLP.core.metrics import SeqLabelEvaluator
  5. from fastNLP.core.predictor import SeqLabelInfer
  6. from fastNLP.core.preprocess import save_pickle, load_pickle
  7. from fastNLP.core.tester import SeqLabelTester
  8. from fastNLP.core.trainer import SeqLabelTrainer
  9. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  10. from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
  11. from fastNLP.loader.model_loader import ModelLoader
  12. from fastNLP.models.sequence_modeling import SeqLabeling
  13. from fastNLP.saver.model_saver import ModelSaver
  14. data_name = "pku_training.utf8"
  15. cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
  16. pickle_path = "./save/"
  17. data_infer_path = "./test/data_for_tests/people_infer.txt"
  18. config_path = "./test/data_for_tests/config"
  19. def infer():
  20. # Load infer configuration, the same as test
  21. test_args = ConfigSection()
  22. ConfigLoader().load_config(config_path, {"POS_infer": test_args})
  23. # fetch dictionary size and number of labels from pickle files
  24. word2index = load_pickle(pickle_path, "word2id.pkl")
  25. test_args["vocab_size"] = len(word2index)
  26. index2label = load_pickle(pickle_path, "label2id.pkl")
  27. test_args["num_classes"] = len(index2label)
  28. # Define the same model
  29. model = SeqLabeling(test_args)
  30. # Dump trained parameters into the model
  31. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  32. print("model loaded!")
  33. # Load infer data
  34. infer_data = TokenizeDataSetLoader().load(data_infer_path)
  35. infer_data.index_field("word_seq", word2index)
  36. # inference
  37. infer = SeqLabelInfer(pickle_path)
  38. results = infer.predict(model, infer_data)
  39. print(results)
  40. def train_test():
  41. # Config Loader
  42. train_args = ConfigSection()
  43. ConfigLoader().load_config(config_path, {"POS_infer": train_args})
  44. # define dataset
  45. data_train = TokenizeDataSetLoader().load(cws_data_path)
  46. word_vocab = Vocabulary()
  47. label_vocab = Vocabulary()
  48. data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
  49. train_args["vocab_size"] = len(word_vocab)
  50. train_args["num_classes"] = len(label_vocab)
  51. save_pickle(word_vocab, pickle_path, "word2id.pkl")
  52. save_pickle(label_vocab, pickle_path, "label2id.pkl")
  53. # Trainer
  54. trainer = SeqLabelTrainer(**train_args.data)
  55. # Model
  56. model = SeqLabeling(train_args)
  57. # Start training
  58. trainer.train(model, data_train)
  59. # Saver
  60. saver = ModelSaver("./save/saved_model.pkl")
  61. saver.save_pytorch(model)
  62. del model, trainer
  63. # Define the same model
  64. model = SeqLabeling(train_args)
  65. # Dump trained parameters into the model
  66. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  67. # Load test configuration
  68. test_args = ConfigSection()
  69. ConfigLoader().load_config(config_path, {"POS_infer": test_args})
  70. test_args["evaluator"] = SeqLabelEvaluator()
  71. # Tester
  72. tester = SeqLabelTester(**test_args.data)
  73. # Start testing
  74. data_train.set_target(truth=True)
  75. tester.test(model, data_train)
  76. def test():
  77. os.makedirs("save", exist_ok=True)
  78. train_test()
  79. infer()
  80. os.system("rm -rf save")
  81. if __name__ == "__main__":
  82. train_test()
  83. infer()