You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_cws.py 3.5 kB

7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import sys
  2. sys.path.append("..")
  3. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  4. from fastNLP.core.trainer import SeqLabelTrainer
  5. from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
  6. from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
  7. from fastNLP.saver.model_saver import ModelSaver
  8. from fastNLP.loader.model_loader import ModelLoader
  9. from fastNLP.core.tester import SeqLabelTester
  10. from fastNLP.models.sequence_modeling import SeqLabeling
  11. from fastNLP.core.predictor import Predictor
  12. data_name = "pku_training.utf8"
  13. # cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8"
  14. cws_data_path = "data_for_tests/cws_pku_utf_8"
  15. pickle_path = "data_for_tests"
  16. data_infer_path = "data_for_tests/people_infer.txt"
  17. def infer():
  18. # Load infer configuration, the same as test
  19. test_args = ConfigSection()
  20. ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})
  21. # fetch dictionary size and number of labels from pickle files
  22. word2index = load_pickle(pickle_path, "word2id.pkl")
  23. test_args["vocab_size"] = len(word2index)
  24. index2label = load_pickle(pickle_path, "class2id.pkl")
  25. test_args["num_classes"] = len(index2label)
  26. # Define the same model
  27. model = SeqLabeling(test_args)
  28. # Dump trained parameters into the model
  29. ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
  30. print("model loaded!")
  31. # Data Loader
  32. raw_data_loader = BaseLoader(data_infer_path)
  33. infer_data = raw_data_loader.load_lines()
  34. """
  35. Transform strings into list of list of strings.
  36. [
  37. [word_11, word_12, ...],
  38. [word_21, word_22, ...],
  39. ...
  40. ]
  41. In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them.
  42. """
  43. # Inference interface
  44. infer = Predictor(pickle_path)
  45. results = infer.predict(model, infer_data)
  46. print(results)
  47. print("Inference finished!")
  48. def train_test():
  49. # Config Loader
  50. train_args = ConfigSection()
  51. ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})
  52. # Data Loader
  53. loader = TokenizeDatasetLoader(cws_data_path)
  54. train_data = loader.load_pku()
  55. # Preprocessor
  56. p = SeqLabelPreprocess()
  57. data_train = p.run(train_data, pickle_path=pickle_path)
  58. train_args["vocab_size"] = p.vocab_size
  59. train_args["num_classes"] = p.num_classes
  60. # Trainer
  61. trainer = SeqLabelTrainer(**train_args.data)
  62. # Model
  63. model = SeqLabeling(train_args)
  64. # Start training
  65. trainer.train(model, data_train)
  66. print("Training finished!")
  67. # Saver
  68. saver = ModelSaver("./data_for_tests/saved_model.pkl")
  69. saver.save_pytorch(model)
  70. print("Model saved!")
  71. del model, trainer, loader
  72. # Define the same model
  73. model = SeqLabeling(train_args)
  74. # Dump trained parameters into the model
  75. ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
  76. print("model loaded!")
  77. # Load test configuration
  78. test_args = ConfigSection()
  79. ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})
  80. # Tester
  81. tester = SeqLabelTester(**test_args.data)
  82. # Start testing
  83. tester.test(model, data_train)
  84. # print test results
  85. print(tester.show_metrics())
  86. print("model tested!")
  87. if __name__ == "__main__":
  88. train_test()
  89. infer()