You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

readme_example.py 2.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # python: 3.5
  2. # pytorch: 0.4
  3. ################
  4. # Test cross validation.
  5. ################
  6. from fastNLP.loader.preprocess import ClassPreprocess
  7. from fastNLP.core.predictor import ClassificationInfer
  8. from fastNLP.core.trainer import ClassificationTrainer
  9. from fastNLP.loader.dataset_loader import ClassDatasetLoader
  10. from fastNLP.models.base_model import BaseModel
  11. from fastNLP.modules import aggregation
  12. from fastNLP.modules import encoder
  13. from fastNLP.modules import decoder
class ClassificationModel(BaseModel):
    """
    Simple text classification model based on CNN.

    Pipeline: word embedding -> 1-D convolution encoder -> max-pooling
    aggregation -> MLP decoder producing per-class scores.
    """

    def __init__(self, num_classes, vocab_size):
        """
        :param num_classes: number of target classes (MLP output size).
        :param vocab_size: vocabulary size (number of embedding rows).
        """
        super(ClassificationModel, self).__init__()
        # NOTE(review): embedding dim (300) and conv width (100, k=3) are
        # hard-coded example hyper-parameters, not tuned values.
        self.emb = encoder.Embedding(nums=vocab_size, dims=300)
        self.enc = encoder.Conv(
            in_channels=300, out_channels=100, kernel_size=3)
        self.agg = aggregation.MaxPool()
        self.dec = decoder.MLP(100, num_classes=num_classes)

    def forward(self, x):
        """
        :param x: batch of token ids; presumably shape [N, L] — TODO confirm
            against the encoder.Embedding contract.
        :return: class scores, shape [N, num_classes] per the comments below.
        """
        x = self.emb(x)  # [N,L] -> [N,L,C]
        x = self.enc(x)  # [N,L,C_in] -> [N,L,C_out]
        x = self.agg(x)  # [N,L,C] -> [N,C]
        x = self.dec(x)  # [N,C] -> [N, N_class]
        return x
# Paths for the example run.
data_dir = 'data'  # directory to save data and model
train_path = 'test/data_for_tests/text_classify.txt'  # training set file

# load dataset
ds_loader = ClassDatasetLoader("train", train_path)
data = ds_loader.load()

# pre-process dataset; cross_val=True splits it into n_fold folds for
# the cross-validation run below (see trainer.cross_validate).
pre = ClassPreprocess(data, data_dir, cross_val=True, n_fold=5)
# pre = ClassPreprocess(data, data_dir)  # non-cross-validation variant
n_classes = pre.num_classes
vocab_size = pre.vocab_size
  41. # construct model
  42. model_args = {
  43. 'num_classes': n_classes,
  44. 'vocab_size': vocab_size
  45. }
  46. model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
# train model
# Hyper-parameters and bookkeeping options handed to the trainer.
train_args = {
    "epochs": 10,
    "batch_size": 50,
    "pickle_path": data_dir,  # where the preprocessed pickles were saved
    "validate": False,        # no per-epoch dev-set evaluation
    "save_best_dev": False,
    "model_saved_path": None,
    "use_cuda": True,         # NOTE(review): assumes a GPU is available — confirm
    "learn_rate": 1e-3,
    "momentum": 0.9}
trainer = ClassificationTrainer(train_args)
# trainer.train(model, ['data_train.pkl', 'data_dev.pkl'])  # single-split variant
# Run k-fold cross validation (k = n_fold passed to ClassPreprocess above).
trainer.cross_validate(model)
# predict using model
# Keep only the first element of each loaded item as inference input;
# presumably each item is (text, label) — TODO confirm against
# ClassDatasetLoader's output format.
data_infer = [x[0] for x in data]
infer = ClassificationInfer(data_dir)
labels_pred = infer.predict(model, data_infer)

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等