You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

readme_example.py 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from fastNLP.core.loss import Loss
  2. from fastNLP.core.optimizer import Optimizer
  3. from fastNLP.core.predictor import ClassificationInfer
  4. from fastNLP.core.preprocess import ClassPreprocess
  5. from fastNLP.core.trainer import ClassificationTrainer
  6. from fastNLP.loader.dataset_loader import ClassDataSetLoader
  7. from fastNLP.models.base_model import BaseModel
  8. from fastNLP.modules import aggregator
  9. from fastNLP.modules import decoder
  10. from fastNLP.modules import encoder
  11. class ClassificationModel(BaseModel):
  12. """
  13. Simple text classification model based on CNN.
  14. """
  15. def __init__(self, num_classes, vocab_size):
  16. super(ClassificationModel, self).__init__()
  17. self.emb = encoder.Embedding(nums=vocab_size, dims=300)
  18. self.enc = encoder.Conv(
  19. in_channels=300, out_channels=100, kernel_size=3)
  20. self.agg = aggregator.MaxPool()
  21. self.dec = decoder.MLP(size_layer=[100, num_classes])
  22. def forward(self, x):
  23. x = self.emb(x) # [N,L] -> [N,L,C]
  24. x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
  25. x = self.agg(x) # [N,L,C] -> [N,C]
  26. x = self.dec(x) # [N,C] -> [N, N_class]
  27. return x
  28. data_dir = 'save/' # directory to save data and model
  29. train_path = './data_for_tests/text_classify.txt' # training set file
  30. # load dataset
  31. ds_loader = ClassDataSetLoader()
  32. data = ds_loader.load()
  33. # pre-process dataset
  34. pre = ClassPreprocess()
  35. train_set, dev_set = pre.run(data, train_dev_split=0.3, pickle_path=data_dir)
  36. n_classes, vocab_size = pre.num_classes, pre.vocab_size
  37. # construct model
  38. model_args = {
  39. 'num_classes': n_classes,
  40. 'vocab_size': vocab_size
  41. }
  42. model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
  43. # construct trainer
  44. train_args = {
  45. "epochs": 3,
  46. "batch_size": 16,
  47. "pickle_path": data_dir,
  48. "validate": False,
  49. "save_best_dev": False,
  50. "model_saved_path": None,
  51. "use_cuda": True,
  52. "loss": Loss("cross_entropy"),
  53. "optimizer": Optimizer("Adam", lr=0.001)
  54. }
  55. trainer = ClassificationTrainer(**train_args)
  56. # start training
  57. trainer.train(model, train_data=train_set, dev_data=dev_set)
  58. # predict using model
  59. data_infer = [x[0] for x in data]
  60. infer = ClassificationInfer(data_dir)
  61. labels_pred = infer.predict(model.cpu(), data_infer)
  62. print(labels_pred)