You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cws.py 3.1 kB

7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import os
  2. from fastNLP.core.predictor import Predictor
  3. from fastNLP.core.preprocess import Preprocessor, load_pickle
  4. from fastNLP.core.tester import SeqLabelTester
  5. from fastNLP.core.trainer import SeqLabelTrainer
  6. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  7. from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
  8. from fastNLP.loader.model_loader import ModelLoader
  9. from fastNLP.models.sequence_modeling import SeqLabeling
  10. from fastNLP.saver.model_saver import ModelSaver
  11. data_name = "pku_training.utf8"
  12. cws_data_path = "test/data_for_tests/cws_pku_utf_8"
  13. pickle_path = "./save/"
  14. data_infer_path = "test/data_for_tests/people_infer.txt"
  15. config_path = "test/data_for_tests/config"
  16. def infer():
  17. # Load infer configuration, the same as test
  18. test_args = ConfigSection()
  19. ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args})
  20. # fetch dictionary size and number of labels from pickle files
  21. word2index = load_pickle(pickle_path, "word2id.pkl")
  22. test_args["vocab_size"] = len(word2index)
  23. index2label = load_pickle(pickle_path, "class2id.pkl")
  24. test_args["num_classes"] = len(index2label)
  25. # Define the same model
  26. model = SeqLabeling(test_args)
  27. # Dump trained parameters into the model
  28. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  29. print("model loaded!")
  30. # Data Loader
  31. raw_data_loader = BaseLoader(data_infer_path)
  32. infer_data = raw_data_loader.load_lines()
  33. # Inference interface
  34. infer = Predictor(pickle_path, "seq_label")
  35. results = infer.predict(model, infer_data)
  36. print(results)
  37. def train_test():
  38. # Config Loader
  39. train_args = ConfigSection()
  40. ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": train_args})
  41. # Data Loader
  42. loader = TokenizeDatasetLoader(cws_data_path)
  43. train_data = loader.load_pku()
  44. # Preprocessor
  45. p = Preprocessor(label_is_seq=True)
  46. data_train = p.run(train_data, pickle_path=pickle_path)
  47. train_args["vocab_size"] = p.vocab_size
  48. train_args["num_classes"] = p.num_classes
  49. # Trainer
  50. trainer = SeqLabelTrainer(**train_args.data)
  51. # Model
  52. model = SeqLabeling(train_args)
  53. # Start training
  54. trainer.train(model, data_train)
  55. # Saver
  56. saver = ModelSaver("./save/saved_model.pkl")
  57. saver.save_pytorch(model)
  58. del model, trainer, loader
  59. # Define the same model
  60. model = SeqLabeling(train_args)
  61. # Dump trained parameters into the model
  62. ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
  63. # Load test configuration
  64. test_args = ConfigSection()
  65. ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args})
  66. # Tester
  67. tester = SeqLabelTester(**test_args.data)
  68. # Start testing
  69. tester.test(model, data_train)
  70. # print test results
  71. print(tester.show_metrics())
  72. def test():
  73. os.makedirs("save", exist_ok=True)
  74. train_test()
  75. infer()
  76. os.system("rm -rf save")
  77. if __name__ == "__main__":
  78. train_test()
  79. infer()