You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

text_classify.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Python: 3.5
  2. # encoding: utf-8
  3. import os
  4. import sys
  5. sys.path.append("..")
  6. from fastNLP.core.predictor import ClassificationInfer
  7. from fastNLP.core.trainer import ClassificationTrainer
  8. from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
  9. from fastNLP.loader.dataset_loader import ClassDatasetLoader
  10. from fastNLP.loader.model_loader import ModelLoader
  11. from fastNLP.loader.preprocess import ClassPreprocess
  12. from fastNLP.models.cnn_text_classification import CNNText
  13. from fastNLP.saver.model_saver import ModelSaver
# Paths for the demo dataset and model artifacts.
data_dir = "./data_for_tests/"
train_file = 'text_classify.txt'
# NOTE(review): model_name is not referenced anywhere in this script —
# confirm whether it is dead or used by code outside this view.
model_name = "model_class.pkl"
  17. def infer():
  18. # load dataset
  19. print("Loading data...")
  20. ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
  21. data = ds_loader.load()
  22. unlabeled_data = [x[0] for x in data]
  23. # pre-process data
  24. pre = ClassPreprocess(data_dir)
  25. vocab_size, n_classes = pre.process(data, "data_train.pkl")
  26. print("vocabulary size:", vocab_size)
  27. print("number of classes:", n_classes)
  28. model_args = ConfigSection()
  29. ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args})
  30. # construct model
  31. print("Building model...")
  32. cnn = CNNText(model_args)
  33. # Dump trained parameters into the model
  34. ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
  35. print("model loaded!")
  36. infer = ClassificationInfer(data_dir)
  37. results = infer.predict(cnn, unlabeled_data)
  38. print(results)
  39. def train():
  40. train_args, model_args = ConfigSection(), ConfigSection()
  41. ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args})
  42. # load dataset
  43. print("Loading data...")
  44. ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
  45. data = ds_loader.load()
  46. print(data[0])
  47. # pre-process data
  48. pre = ClassPreprocess(data_dir)
  49. vocab_size, n_classes = pre.process(data, "data_train.pkl")
  50. print("vocabulary size:", vocab_size)
  51. print("number of classes:", n_classes)
  52. # construct model
  53. print("Building model...")
  54. cnn = CNNText(model_args)
  55. # train
  56. print("Training...")
  57. trainer = ClassificationTrainer(train_args)
  58. trainer.train(cnn)
  59. print("Training finished!")
  60. saver = ModelSaver("./data_for_tests/saved_model.pkl")
  61. saver.save_pytorch(cnn)
  62. print("Model saved!")
if __name__ == "__main__":
    # Training is disabled by default; uncomment to retrain before inference.
    # train()
    infer()

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等