"""Train, save, reload and evaluate a SeqLabeling model on the test data in ../data_for_tests."""
import os

from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.core.utils import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.dataset_loader import TokenizeDataSetLoader
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.model_saver import ModelSaver
from fastNLP.models.sequence_modeling import SeqLabeling

pickle_path = "./seq_label/"
model_name = "seq_label_model.pkl"
config_dir = "../data_for_tests/config"
data_path = "../data_for_tests/people.txt"
data_infer_path = "../data_for_tests/people_infer.txt"


def test_training():
    """Run the sequence-labelling pipeline end to end.

    Steps:
    - Read the trainer/model config sections.
    - Build word and label vocabularies from the corpus at ``data_path`` and
      pickle them under ``pickle_path``.
    - Train ``SeqLabeling`` on one split of the data.
    - Save the weights, reload them into a fresh model, and run
      ``SeqLabelTester`` on the other split.
    """
    # Hyper-parameters for the trainer and the model come from the test config.
    train_section = ConfigSection()
    model_section = ConfigSection()
    sections = {
        "test_seq_label_trainer": train_section,
        "test_seq_label_model": model_section,
    }
    ConfigLoader().load_config(config_dir, sections)

    # Build vocabularies from the raw corpus and convert tokens to indices.
    dataset = TokenizeDataSetLoader().load(data_path)
    words = Vocabulary()
    labels = Vocabulary()
    dataset.update_vocab(word_seq=words, label_seq=labels)
    dataset.index_field("word_seq", words).index_field("label_seq", labels)
    dataset.set_origin_len("word_seq")
    # `truth` is not flagged as a target while training; it is switched on
    # before testing below. NOTE(review): presumably so the labels reach the
    # model as an input during training -- confirm against SeqLabeling.
    dataset.rename_field("label_seq", "truth").set_target(truth=False)
    train_split, dev_split = dataset.split(0.3, shuffle=True)

    # The model's dimensions depend on the vocabularies just built.
    model_section["vocab_size"] = len(words)
    model_section["num_classes"] = len(labels)

    for vocab, file_name in ((words, "word2id.pkl"), (labels, "label2id.pkl")):
        save_pickle(vocab, pickle_path, file_name)

    trainer = SeqLabelTrainer(
        epochs=train_section["epochs"],
        batch_size=train_section["batch_size"],
        validate=False,
        use_cuda=False,
        pickle_path=pickle_path,
        save_best_dev=train_section["save_best_dev"],
        model_name=model_name,
        optimizer=Optimizer("SGD", lr=0.01, momentum=0.9),
    )

    model = SeqLabeling(model_section)
    trainer.train(model, train_split, dev_split)

    # Persist the trained weights, then drop the in-memory model and trainer.
    model_path = os.path.join(pickle_path, model_name)
    ModelSaver(model_path).save_pytorch(model)
    del model, trainer

    # Rebuild the same architecture and load the saved weights into it.
    model = SeqLabeling(model_section)
    ModelLoader.load_pytorch(model, model_path)

    # NOTE(review): the tester section is loaded, but none of its values are
    # used below (the tester's batch size is hard-coded).
    test_section = ConfigSection()
    ConfigLoader().load_config(config_dir, {"test_seq_label_tester": test_section})

    tester = SeqLabelTester(
        batch_size=4,
        use_cuda=False,
        pickle_path=pickle_path,
        model_name="seq_label_in_test.pkl",
        evaluator=SeqLabelEvaluator(),
    )

    # Evaluate on the held-out split, now with `truth` marked as the target.
    dev_split.set_target(truth=True)
    tester.test(model, dev_split)


if __name__ == "__main__":
    test_training()