You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

quickstart.rst 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. ==========
  2. Quickstart
  3. ==========
  4. Example
  5. -------
  6. Basic Usage
  7. ~~~~~~~~~~~
  8. A typical fastNLP routine is composed of four phases: loading a dataset,
  9. pre-processing the data, constructing a model, and training the model.
  10. .. code:: python
  11. from fastNLP.models.base_model import BaseModel
  12. from fastNLP.modules import encoder
  13. from fastNLP.modules import aggregation
  14. from fastNLP.modules import decoder
  15. from fastNLP.loader.dataset_loader import ClassDatasetLoader
  16. from fastNLP.loader.preprocess import ClassPreprocess
  17. from fastNLP.core.trainer import ClassificationTrainer
  18. from fastNLP.core.inference import ClassificationInfer
  19. class ClassificationModel(BaseModel):
  20. """
  21. Simple text classification model based on CNN.
  22. """
  23. def __init__(self, num_classes, vocab_size):
  24. super(ClassificationModel, self).__init__()
  25. self.emb = encoder.Embedding(nums=vocab_size, dims=300)
  26. self.enc = encoder.Conv(
  27. in_channels=300, out_channels=100, kernel_size=3)
  28. self.agg = aggregation.MaxPool()
  29. self.dec = decoder.MLP(100, num_classes=num_classes)
  30. def forward(self, x):
  31. x = self.emb(x) # [N,L] -> [N,L,C]
  32. x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
  33. x = self.agg(x) # [N,L,C] -> [N,C]
  34. x = self.dec(x) # [N,C] -> [N, N_class]
  35. return x
  36. data_dir = 'data' # directory to save data and model
  37. train_path = 'test/data_for_tests/text_classify.txt' # training set file
  38. # load dataset
  39. ds_loader = ClassDatasetLoader("train", train_path)
  40. data = ds_loader.load()
  41. # pre-process dataset
  42. pre = ClassPreprocess(data_dir)
  43. vocab_size, n_classes = pre.process(data, "data_train.pkl")
  44. # construct model
  45. model_args = {
  46. 'num_classes': n_classes,
  47. 'vocab_size': vocab_size
  48. }
  49. model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
  50. # train model
  51. train_args = {
  52. "epochs": 20,
  53. "batch_size": 50,
  54. "pickle_path": data_dir,
  55. "validate": False,
  56. "save_best_dev": False,
  57. "model_saved_path": None,
  58. "use_cuda": True,
  59. "learn_rate": 1e-3,
  60. "momentum": 0.9}
  61. trainer = ClassificationTrainer(train_args)
  62. trainer.train(model)
  63. # predict using model
  64. seqs = [x[0] for x in data]
  65. infer = ClassificationInfer(data_dir)
  66. labels_pred = infer.predict(model, seqs)

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等