You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

text_classify.py 2.5 kB

  1. # Python: 3.5
  2. # encoding: utf-8
  3. import os
  4. import sys
  5. sys.path.append("..")
  6. from fastNLP.core.predictor import ClassificationInfer
  7. from fastNLP.core.trainer import ClassificationTrainer
  8. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  9. from fastNLP.loader.dataset_loader import ClassDatasetLoader
  10. from fastNLP.loader.model_loader import ModelLoader
  11. from fastNLP.core.preprocess import ClassPreprocess
  12. from fastNLP.models.cnn_text_classification import CNNText
  13. from fastNLP.saver.model_saver import ModelSaver
  14. save_path = "./test_classification/"
  15. data_dir = "./data_for_tests/"
  16. train_file = 'text_classify.txt'
  17. model_name = "model_class.pkl"
  18. def infer():
  19. # load dataset
  20. print("Loading data...")
  21. ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
  22. data = ds_loader.load()
  23. unlabeled_data = [x[0] for x in data]
  24. # pre-process data
  25. pre = ClassPreprocess()
  26. vocab_size, n_classes = pre.run(data, pickle_path=save_path)
  27. print("vocabulary size:", vocab_size)
  28. print("number of classes:", n_classes)
  29. model_args = ConfigSection()
  30. ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args})
  31. # construct model
  32. print("Building model...")
  33. cnn = CNNText(model_args)
  34. # Dump trained parameters into the model
  35. ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
  36. print("model loaded!")
  37. infer = ClassificationInfer(data_dir)
  38. results = infer.predict(cnn, unlabeled_data)
  39. print(results)
  40. def train():
  41. train_args, model_args = ConfigSection(), ConfigSection()
  42. ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args})
  43. # load dataset
  44. print("Loading data...")
  45. ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
  46. data = ds_loader.load()
  47. print(data[0])
  48. # pre-process data
  49. pre = ClassPreprocess()
  50. data_train = pre.run(data, pickle_path=save_path)
  51. print("vocabulary size:", pre.vocab_size)
  52. print("number of classes:", pre.num_classes)
  53. # construct model
  54. print("Building model...")
  55. model = CNNText(model_args)
  56. # train
  57. print("Training...")
  58. trainer = ClassificationTrainer(train_args)
  59. trainer.train(model, data_train)
  60. print("Training finished!")
  61. saver = ModelSaver("./data_for_tests/saved_model.pkl")
  62. saver.save_pytorch(model)
  63. print("Model saved!")
  64. if __name__ == "__main__":
  65. train()
  66. # infer()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等