You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

text_classify.py 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Python: 3.5
  2. # encoding: utf-8
  3. import argparse
  4. import os
  5. import sys
  6. sys.path.append("..")
  7. from fastNLP.core.predictor import ClassificationInfer
  8. from fastNLP.core.trainer import ClassificationTrainer
  9. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  10. from fastNLP.loader.dataset_loader import ClassDataSetLoader
  11. from fastNLP.loader.model_loader import ModelLoader
  12. from fastNLP.models.cnn_text_classification import CNNText
  13. from fastNLP.saver.model_saver import ModelSaver
  14. from fastNLP.core.optimizer import Optimizer
  15. from fastNLP.core.loss import Loss
  16. from fastNLP.core.dataset import TextClassifyDataSet
  17. from fastNLP.core.preprocess import save_pickle, load_pickle
  18. parser = argparse.ArgumentParser()
  19. parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
  20. parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt",
  21. help="path to the training data")
  22. parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file")
  23. parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model")
  24. args = parser.parse_args()
  25. save_dir = args.save
  26. train_data_dir = args.train
  27. model_name = args.model_name
  28. config_dir = args.config
  29. def infer():
  30. # load dataset
  31. print("Loading data...")
  32. word_vocab = load_pickle(save_dir, "word2id.pkl")
  33. label_vocab = load_pickle(save_dir, "label2id.pkl")
  34. print("vocabulary size:", len(word_vocab))
  35. print("number of classes:", len(label_vocab))
  36. infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
  37. infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})
  38. model_args = ConfigSection()
  39. model_args["vocab_size"] = len(word_vocab)
  40. model_args["num_classes"] = len(label_vocab)
  41. ConfigLoader.load_config(config_dir, {"text_class_model": model_args})
  42. # construct model
  43. print("Building model...")
  44. cnn = CNNText(model_args)
  45. # Dump trained parameters into the model
  46. ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name))
  47. print("model loaded!")
  48. infer = ClassificationInfer(pickle_path=save_dir)
  49. results = infer.predict(cnn, infer_data)
  50. print(results)
  51. def train():
  52. train_args, model_args = ConfigSection(), ConfigSection()
  53. ConfigLoader.load_config(config_dir, {"text_class": train_args})
  54. # load dataset
  55. print("Loading data...")
  56. data = TextClassifyDataSet(load_func=ClassDataSetLoader.load)
  57. data.load(train_data_dir)
  58. print("vocabulary size:", len(data.word_vocab))
  59. print("number of classes:", len(data.label_vocab))
  60. save_pickle(data.word_vocab, save_dir, "word2id.pkl")
  61. save_pickle(data.label_vocab, save_dir, "label2id.pkl")
  62. model_args["num_classes"] = len(data.label_vocab)
  63. model_args["vocab_size"] = len(data.word_vocab)
  64. # construct model
  65. print("Building model...")
  66. model = CNNText(model_args)
  67. # train
  68. print("Training...")
  69. trainer = ClassificationTrainer(epochs=train_args["epochs"],
  70. batch_size=train_args["batch_size"],
  71. validate=train_args["validate"],
  72. use_cuda=train_args["use_cuda"],
  73. pickle_path=save_dir,
  74. save_best_dev=train_args["save_best_dev"],
  75. model_name=model_name,
  76. loss=Loss("cross_entropy"),
  77. optimizer=Optimizer("SGD", lr=0.001, momentum=0.9))
  78. trainer.train(model, data)
  79. print("Training finished!")
  80. saver = ModelSaver(os.path.join(save_dir, model_name))
  81. saver.save_pytorch(model)
  82. print("Model saved!")
if __name__ == "__main__":
    # Train first so the pickled vocabularies and the saved model exist,
    # then run inference using those freshly produced artifacts.
    train()
    infer()