# Python: 3.5
# encoding: utf-8
"""Train a CNN text classifier with fastNLP, save it, then run inference.

Run as a script: ``train()`` fits and saves the model, ``infer()`` reloads
the saved parameters and prints predictions. Paths and the model file name
come from the command-line options defined below.
"""

import argparse
import os
import sys

# Must run before the fastNLP imports below: it adds the parent directory to
# the import search path (presumably so an in-tree fastNLP checkout resolves).
sys.path.append("..")
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.loss import Loss

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./test_classification/",
                    help="path to save pickle files")
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt",
                    help="path to the training data")
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config",
                    help="path to the config file")
parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl",
                    help="the name of the model")

# NOTE: arguments are parsed at import time, so importing this module from
# another program consumes that program's sys.argv.
args = parser.parse_args()
save_dir = args.save
train_data_dir = args.train
model_name = args.model_name
config_dir = args.config


def infer():
    """Reload the saved model and print its predictions.

    Loads the dataset at ``train_data_dir``, runs it through
    ``ClassPreprocess``, builds a ``CNNText`` sized from the preprocessor's
    ``vocab_size`` and ``num_classes``, loads parameters from
    ``save_dir/model_name`` and prints the result of
    ``ClassificationInfer.predict``.
    """
    # load dataset
    print("Loading data...")
    ds_loader = ClassDatasetLoader("train", train_data_dir)
    data = ds_loader.load()
    # Keep only the first field of every sample (presumably the text, with
    # the label in a later field -- TODO confirm against ClassDatasetLoader).
    unlabeled_data = [x[0] for x in data]

    # pre-process data
    pre = ClassPreprocess()
    # The return value is not needed here; the call is kept because
    # pre.vocab_size and pre.num_classes are read right after it.
    pre.run(data, pickle_path=save_dir)
    print("vocabulary size:", pre.vocab_size)
    print("number of classes:", pre.num_classes)

    model_args = ConfigSection()
    # TODO: load from config file
    model_args["vocab_size"] = pre.vocab_size
    model_args["num_classes"] = pre.num_classes
    # ConfigLoader.load_config(config_dir, {"text_class_model": model_args})

    # construct model
    print("Building model...")
    cnn = CNNText(model_args)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name))
    print("model loaded!")

    # Named `predictor` rather than `infer` so the local does not shadow
    # this function's own name.
    predictor = ClassificationInfer(pickle_path=save_dir)
    results = predictor.predict(cnn, unlabeled_data)
    print(results)


def train():
    """Train a ``CNNText`` classifier and save it to ``save_dir/model_name``.

    Training options (``epochs``, ``batch_size``, ``validate``, ``use_cuda``,
    ``save_best_dev``) are read from the ``text_class`` section of the config
    file at ``config_dir``. The loss ("cross_entropy") and the optimizer
    (SGD, lr=0.001, momentum=0.9) are fixed in code.
    """
    train_args, model_args = ConfigSection(), ConfigSection()
    ConfigLoader.load_config(config_dir, {"text_class": train_args})

    # load dataset
    print("Loading data...")
    ds_loader = ClassDatasetLoader("train", train_data_dir)
    data = ds_loader.load()
    print(data[0])

    # pre-process data
    pre = ClassPreprocess()
    data_train = pre.run(data, pickle_path=save_dir)
    print("vocabulary size:", pre.vocab_size)
    print("number of classes:", pre.num_classes)

    model_args["num_classes"] = pre.num_classes
    model_args["vocab_size"] = pre.vocab_size

    # construct model
    print("Building model...")
    model = CNNText(model_args)

    # ConfigSaver().save_config(config_dir, {"text_class_model": model_args})

    # train
    print("Training...")

    # 1
    # trainer = ClassificationTrainer(train_args)

    # 2
    trainer = ClassificationTrainer(epochs=train_args["epochs"],
                                    batch_size=train_args["batch_size"],
                                    validate=train_args["validate"],
                                    use_cuda=train_args["use_cuda"],
                                    pickle_path=save_dir,
                                    save_best_dev=train_args["save_best_dev"],
                                    model_name=model_name,
                                    loss=Loss("cross_entropy"),
                                    optimizer=Optimizer("SGD", lr=0.001, momentum=0.9))
    trainer.train(model, data_train)

    print("Training finished!")

    saver = ModelSaver(os.path.join(save_dir, model_name))
    saver.save_pytorch(model)
    print("Model saved!")


if __name__ == "__main__":
    train()
    infer()