You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

text_classify.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Python: 3.5
  2. # encoding: utf-8
  3. import argparse
  4. import os
  5. import sys
  6. sys.path.append("..")
  7. from fastNLP.core.predictor import ClassificationInfer
  8. from fastNLP.core.trainer import ClassificationTrainer
  9. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  10. from fastNLP.loader.dataset_loader import ClassDatasetLoader
  11. from fastNLP.loader.model_loader import ModelLoader
  12. from fastNLP.core.preprocess import ClassPreprocess
  13. from fastNLP.models.cnn_text_classification import CNNText
  14. from fastNLP.saver.model_saver import ModelSaver
  15. from fastNLP.core.optimizer import Optimizer
  16. from fastNLP.core.loss import Loss
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")
  19. parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt",
  20. help="path to the training data")
  21. parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file")
  22. parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model")
  23. args = parser.parse_args()
  24. save_dir = args.save
  25. train_data_dir = args.train
  26. model_name = args.model_name
  27. config_dir = args.config
  28. def infer():
  29. # load dataset
  30. print("Loading data...")
  31. ds_loader = ClassDatasetLoader("train", train_data_dir)
  32. data = ds_loader.load()
  33. unlabeled_data = [x[0] for x in data]
  34. # pre-process data
  35. pre = ClassPreprocess()
  36. data = pre.run(data, pickle_path=save_dir)
  37. print("vocabulary size:", pre.vocab_size)
  38. print("number of classes:", pre.num_classes)
  39. model_args = ConfigSection()
  40. # TODO: load from config file
  41. model_args["vocab_size"] = pre.vocab_size
  42. model_args["num_classes"] = pre.num_classes
  43. # ConfigLoader.load_config(config_dir, {"text_class_model": model_args})
  44. # construct model
  45. print("Building model...")
  46. cnn = CNNText(model_args)
  47. # Dump trained parameters into the model
  48. ModelLoader.load_pytorch(cnn, os.path.join(save_dir, model_name))
  49. print("model loaded!")
  50. infer = ClassificationInfer(pickle_path=save_dir)
  51. results = infer.predict(cnn, unlabeled_data)
  52. print(results)
  53. def train():
  54. train_args, model_args = ConfigSection(), ConfigSection()
  55. ConfigLoader.load_config(config_dir, {"text_class": train_args})
  56. # load dataset
  57. print("Loading data...")
  58. ds_loader = ClassDatasetLoader("train", train_data_dir)
  59. data = ds_loader.load()
  60. print(data[0])
  61. # pre-process data
  62. pre = ClassPreprocess()
  63. data_train = pre.run(data, pickle_path=save_dir)
  64. print("vocabulary size:", pre.vocab_size)
  65. print("number of classes:", pre.num_classes)
  66. model_args["num_classes"] = pre.num_classes
  67. model_args["vocab_size"] = pre.vocab_size
  68. # construct model
  69. print("Building model...")
  70. model = CNNText(model_args)
  71. # ConfigSaver().save_config(config_dir, {"text_class_model": model_args})
  72. # train
  73. print("Training...")
  74. # 1
  75. # trainer = ClassificationTrainer(train_args)
  76. # 2
  77. trainer = ClassificationTrainer(epochs=train_args["epochs"],
  78. batch_size=train_args["batch_size"],
  79. validate=train_args["validate"],
  80. use_cuda=train_args["use_cuda"],
  81. pickle_path=save_dir,
  82. save_best_dev=train_args["save_best_dev"],
  83. model_name=model_name,
  84. loss=Loss("cross_entropy"),
  85. optimizer=Optimizer("SGD", lr=0.001, momentum=0.9))
  86. trainer.train(model, data_train)
  87. print("Training finished!")
  88. saver = ModelSaver(os.path.join(save_dir, model_name))
  89. saver.save_pytorch(model)
  90. print("Model saved!")
  91. if __name__ == "__main__":
  92. train()
  93. infer()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等