You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_seq_label.py 2.8 kB

(line-number gutter: lines 1–85)
  1. import os
  2. from fastNLP.core.optimizer import Optimizer
  3. from fastNLP.core.preprocess import SeqLabelPreprocess
  4. from fastNLP.core.tester import SeqLabelTester
  5. from fastNLP.core.trainer import SeqLabelTrainer
  6. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  7. from fastNLP.loader.dataset_loader import POSDatasetLoader
  8. from fastNLP.loader.model_loader import ModelLoader
  9. from fastNLP.models.sequence_modeling import SeqLabeling
  10. from fastNLP.saver.model_saver import ModelSaver
  11. pickle_path = "./seq_label/"
  12. model_name = "seq_label_model.pkl"
  13. config_dir = "test/data_for_tests/config"
  14. data_path = "test/data_for_tests/people.txt"
  15. data_infer_path = "test/data_for_tests/people_infer.txt"
  16. def test_training():
  17. # Config Loader
  18. trainer_args = ConfigSection()
  19. model_args = ConfigSection()
  20. ConfigLoader("_").load_config(config_dir, {
  21. "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})
  22. # Data Loader
  23. pos_loader = POSDatasetLoader(data_path)
  24. train_data = pos_loader.load_lines()
  25. # Preprocessor
  26. p = SeqLabelPreprocess()
  27. data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
  28. model_args["vocab_size"] = p.vocab_size
  29. model_args["num_classes"] = p.num_classes
  30. trainer = SeqLabelTrainer(
  31. epochs=trainer_args["epochs"],
  32. batch_size=trainer_args["batch_size"],
  33. validate=False,
  34. use_cuda=False,
  35. pickle_path=pickle_path,
  36. save_best_dev=trainer_args["save_best_dev"],
  37. model_name=model_name,
  38. optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
  39. )
  40. # Model
  41. model = SeqLabeling(model_args)
  42. # Start training
  43. trainer.train(model, data_train, data_dev)
  44. # Saver
  45. saver = ModelSaver(os.path.join(pickle_path, model_name))
  46. saver.save_pytorch(model)
  47. del model, trainer, pos_loader
  48. # Define the same model
  49. model = SeqLabeling(model_args)
  50. # Dump trained parameters into the model
  51. ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name))
  52. # Load test configuration
  53. tester_args = ConfigSection()
  54. ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})
  55. # Tester
  56. tester = SeqLabelTester(save_output=False,
  57. save_loss=True,
  58. save_best_dev=False,
  59. batch_size=4,
  60. use_cuda=False,
  61. pickle_path=pickle_path,
  62. model_name="seq_label_in_test.pkl",
  63. print_every_step=1
  64. )
  65. # Start testing with validation data
  66. tester.test(model, data_dev)
  67. loss, accuracy = tester.metrics
  68. assert 0 < accuracy < 1