You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cws.py 3.5 kB

7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
import os
import shutil

from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.core.utils import save_pickle, load_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.dataset_loader import TokenizeDataSetLoader, RawDataSetLoader
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.model_saver import ModelSaver
from fastNLP.models.sequence_modeling import SeqLabeling
# Paths used by the functions in this file. They are relative, so the script
# presumably has to be run from the repository root -- TODO confirm.

# NOTE(review): not referenced anywhere in the visible part of this file.
data_name = "pku_training.utf8"
# Training data loaded by train_test() through TokenizeDataSetLoader.
cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
# Directory where the vocabulary pickles are written by train_test() and read by infer().
pickle_path = "./save/"
# Raw input loaded by infer() through RawDataSetLoader.
data_infer_path = "./test/data_for_tests/people_infer.txt"
# Config file from which the "POS_infer" section is loaded.
config_path = "./test/data_for_tests/config"
  18. def infer():
  19. # Load infer configuration, the same as test
  20. test_args = ConfigSection()
  21. ConfigLoader().load_config(config_path, {"POS_infer": test_args})
  22. # fetch dictionary size and number of labels from pickle files
  23. word2index = load_pickle(pickle_path, "word2id.pkl")
  24. test_args["vocab_size"] = len(word2index)
  25. index2label = load_pickle(pickle_path, "label2id.pkl")
  26. test_args["num_classes"] = len(index2label)
  27. # Define the same model
  28. model = SeqLabeling(test_args)
  29. # Dump trained parameters into the model
  30. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  31. print("model loaded!")
  32. # Load infer data
  33. infer_data = RawDataSetLoader().load(data_infer_path)
  34. infer_data.index_field("word_seq", word2index)
  35. infer_data.set_origin_len("word_seq")
  36. # inference
  37. infer = SeqLabelInfer(pickle_path)
  38. results = infer.predict(model, infer_data)
  39. print(results)
  40. def train_test():
  41. # Config Loader
  42. train_args = ConfigSection()
  43. ConfigLoader().load_config(config_path, {"POS_infer": train_args})
  44. # define dataset
  45. data_train = TokenizeDataSetLoader().load(cws_data_path)
  46. word_vocab = Vocabulary()
  47. label_vocab = Vocabulary()
  48. data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
  49. data_train.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
  50. data_train.set_origin_len("word_seq")
  51. data_train.rename_field("label_seq", "truth").set_target(truth=False)
  52. train_args["vocab_size"] = len(word_vocab)
  53. train_args["num_classes"] = len(label_vocab)
  54. save_pickle(word_vocab, pickle_path, "word2id.pkl")
  55. save_pickle(label_vocab, pickle_path, "label2id.pkl")
  56. # Trainer
  57. trainer = SeqLabelTrainer(**train_args.data)
  58. # Model
  59. model = SeqLabeling(train_args)
  60. # Start training
  61. trainer.train(model, data_train)
  62. # Saver
  63. saver = ModelSaver("./save/saved_model.pkl")
  64. saver.save_pytorch(model)
  65. del model, trainer
  66. # Define the same model
  67. model = SeqLabeling(train_args)
  68. # Dump trained parameters into the model
  69. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  70. # Load test configuration
  71. test_args = ConfigSection()
  72. ConfigLoader().load_config(config_path, {"POS_infer": test_args})
  73. test_args["evaluator"] = SeqLabelEvaluator()
  74. # Tester
  75. tester = SeqLabelTester(**test_args.data)
  76. # Start testing
  77. data_train.set_target(truth=True)
  78. tester.test(model, data_train)
  79. def test():
  80. os.makedirs("save", exist_ok=True)
  81. train_test()
  82. infer()
  83. os.system("rm -rf save")
  84. if __name__ == "__main__":
  85. train_test()
  86. infer()