You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cws.py 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import os
  2. from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
  3. from fastNLP.core.metrics import SeqLabelEvaluator
  4. from fastNLP.core.predictor import SeqLabelInfer
  5. from fastNLP.core.preprocess import save_pickle, load_pickle
  6. from fastNLP.core.tester import SeqLabelTester
  7. from fastNLP.core.trainer import SeqLabelTrainer
  8. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  9. from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
  10. from fastNLP.loader.model_loader import ModelLoader
  11. from fastNLP.models.sequence_modeling import SeqLabeling
  12. from fastNLP.saver.model_saver import ModelSaver
  13. data_name = "pku_training.utf8"
  14. cws_data_path = "./test/data_for_tests/cws_pku_utf_8"
  15. pickle_path = "./save/"
  16. data_infer_path = "./test/data_for_tests/people_infer.txt"
  17. config_path = "./test/data_for_tests/config"
  18. def infer():
  19. # Load infer configuration, the same as test
  20. test_args = ConfigSection()
  21. ConfigLoader().load_config(config_path, {"POS_infer": test_args})
  22. # fetch dictionary size and number of labels from pickle files
  23. word2index = load_pickle(pickle_path, "word2id.pkl")
  24. test_args["vocab_size"] = len(word2index)
  25. index2label = load_pickle(pickle_path, "label2id.pkl")
  26. test_args["num_classes"] = len(index2label)
  27. # Define the same model
  28. model = SeqLabeling(test_args)
  29. # Dump trained parameters into the model
  30. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  31. print("model loaded!")
  32. # Load infer data
  33. infer_data = SeqLabelDataSet(loader=BaseLoader())
  34. infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True)
  35. # inference
  36. infer = SeqLabelInfer(pickle_path)
  37. results = infer.predict(model, infer_data)
  38. print(results)
  39. def train_test():
  40. # Config Loader
  41. train_args = ConfigSection()
  42. ConfigLoader().load_config(config_path, {"POS_infer": train_args})
  43. # define dataset
  44. data_train = SeqLabelDataSet(loader=TokenizeDataSetLoader())
  45. data_train.load(cws_data_path)
  46. train_args["vocab_size"] = len(data_train.word_vocab)
  47. train_args["num_classes"] = len(data_train.label_vocab)
  48. save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl")
  49. save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl")
  50. # Trainer
  51. trainer = SeqLabelTrainer(**train_args.data)
  52. # Model
  53. model = SeqLabeling(train_args)
  54. # Start training
  55. trainer.train(model, data_train)
  56. # Saver
  57. saver = ModelSaver("./save/saved_model.pkl")
  58. saver.save_pytorch(model)
  59. del model, trainer
  60. # Define the same model
  61. model = SeqLabeling(train_args)
  62. # Dump trained parameters into the model
  63. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  64. # Load test configuration
  65. test_args = ConfigSection()
  66. ConfigLoader().load_config(config_path, {"POS_infer": test_args})
  67. test_args["evaluator"] = SeqLabelEvaluator()
  68. # Tester
  69. tester = SeqLabelTester(**test_args.data)
  70. # Start testing
  71. change_field_is_target(data_train, "truth", True)
  72. tester.test(model, data_train)
  73. def test():
  74. os.makedirs("save", exist_ok=True)
  75. train_test()
  76. infer()
  77. os.system("rm -rf save")
  78. if __name__ == "__main__":
  79. train_test()
  80. infer()