You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_seq_label.py 3.0 kB

7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import os
  2. from fastNLP.core.metrics import SeqLabelEvaluator
  3. from fastNLP.core.optimizer import Optimizer
  4. from fastNLP.core.tester import Tester
  5. from fastNLP.core.trainer import Trainer
  6. from fastNLP.core.utils import save_pickle
  7. from fastNLP.core.vocabulary import Vocabulary
  8. from fastNLP.io.config_loader import ConfigLoader, ConfigSection
  9. from fastNLP.io.dataset_loader import TokenizeDataSetLoader
  10. from fastNLP.io.model_loader import ModelLoader
  11. from fastNLP.io.model_saver import ModelSaver
  12. from fastNLP.models.sequence_modeling import SeqLabeling
  13. pickle_path = "./seq_label/"
  14. model_name = "seq_label_model.pkl"
  15. config_dir = "../data_for_tests/config"
  16. data_path = "../data_for_tests/people.txt"
  17. data_infer_path = "../data_for_tests/people_infer.txt"
  18. def test_training():
  19. # Config Loader
  20. trainer_args = ConfigSection()
  21. model_args = ConfigSection()
  22. ConfigLoader().load_config(config_dir, {
  23. "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})
  24. data_set = TokenizeDataSetLoader().load(data_path)
  25. word_vocab = Vocabulary()
  26. label_vocab = Vocabulary()
  27. data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
  28. data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
  29. data_set.set_origin_len("word_seq")
  30. data_set.rename_field("label_seq", "truth").set_target(truth=False)
  31. data_train, data_dev = data_set.split(0.3, shuffle=True)
  32. model_args["vocab_size"] = len(word_vocab)
  33. model_args["num_classes"] = len(label_vocab)
  34. save_pickle(word_vocab, pickle_path, "word2id.pkl")
  35. save_pickle(label_vocab, pickle_path, "label2id.pkl")
  36. trainer = Trainer(
  37. epochs=trainer_args["epochs"],
  38. batch_size=trainer_args["batch_size"],
  39. validate=False,
  40. use_cuda=False,
  41. pickle_path=pickle_path,
  42. save_best_dev=trainer_args["save_best_dev"],
  43. model_name=model_name,
  44. optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
  45. )
  46. # Model
  47. model = SeqLabeling(model_args)
  48. # Start training
  49. trainer.train(model, data_train, data_dev)
  50. # Saver
  51. saver = ModelSaver(os.path.join(pickle_path, model_name))
  52. saver.save_pytorch(model)
  53. del model, trainer
  54. # Define the same model
  55. model = SeqLabeling(model_args)
  56. # Dump trained parameters into the model
  57. ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
  58. # Load test configuration
  59. tester_args = ConfigSection()
  60. ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args})
  61. # Tester
  62. tester = Tester(batch_size=4,
  63. use_cuda=False,
  64. pickle_path=pickle_path,
  65. model_name="seq_label_in_test.pkl",
  66. evaluator=SeqLabelEvaluator()
  67. )
  68. # Start testing with validation data
  69. data_dev.set_target(truth=True)
  70. tester.test(model, data_dev)
  71. if __name__ == "__main__":
  72. test_training()