You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

text_classify.py 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Python: 3.5
  2. # encoding: utf-8
  3. import os
  4. from fastNLP.core.inference import ClassificationInfer
  5. from fastNLP.core.trainer import ClassificationTrainer
  6. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  7. from fastNLP.loader.dataset_loader import ClassDatasetLoader
  8. from fastNLP.loader.model_loader import ModelLoader
  9. from fastNLP.loader.preprocess import ClassPreprocess
  10. from fastNLP.models.cnn_text_classification import CNNText
  11. from fastNLP.saver.model_saver import ModelSaver
  12. data_dir = "./data_for_tests/"
  13. train_file = 'text_classify.txt'
  14. model_name = "model_class.pkl"
  15. def infer():
  16. # load dataset
  17. print("Loading data...")
  18. ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
  19. data = ds_loader.load()
  20. unlabeled_data = [x[0] for x in data]
  21. # pre-process data
  22. pre = ClassPreprocess(data_dir)
  23. vocab_size, n_classes = pre.process(data, "data_train.pkl")
  24. print("vocabulary size:", vocab_size)
  25. print("number of classes:", n_classes)
  26. model_args = ConfigSection()
  27. ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args})
  28. # construct model
  29. print("Building model...")
  30. cnn = CNNText(model_args)
  31. # Dump trained parameters into the model
  32. ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
  33. print("model loaded!")
  34. infer = ClassificationInfer(data_dir)
  35. results = infer.predict(cnn, unlabeled_data)
  36. print(results)
  37. def train():
  38. train_args, model_args = ConfigSection(), ConfigSection()
  39. ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args})
  40. # load dataset
  41. print("Loading data...")
  42. ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
  43. data = ds_loader.load()
  44. print(data[0])
  45. # pre-process data
  46. pre = ClassPreprocess(data_dir)
  47. vocab_size, n_classes = pre.process(data, "data_train.pkl")
  48. print("vocabulary size:", vocab_size)
  49. print("number of classes:", n_classes)
  50. # construct model
  51. print("Building model...")
  52. cnn = CNNText(model_args)
  53. # train
  54. print("Training...")
  55. trainer = ClassificationTrainer(train_args)
  56. trainer.train(cnn)
  57. print("Training finished!")
  58. saver = ModelSaver("./data_for_tests/saved_model.pkl")
  59. saver.save_pytorch(cnn)
  60. print("Model saved!")
  61. if __name__ == "__main__":
  62. # train()
  63. infer()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等