You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer.py 1.3 kB

123456789101112131415161718192021222324252627282930313233
  1. import os
  2. import torch.nn as nn
  3. import unittest
  4. from fastNLP.core.trainer import SeqLabelTrainer
  5. from fastNLP.core.loss import Loss
  6. from fastNLP.core.optimizer import Optimizer
  7. from fastNLP.models.sequence_modeling import SeqLabeling
  8. class TestTrainer(unittest.TestCase):
  9. def test_case_1(self):
  10. args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
  11. "save_best_dev": True, "model_name": "default_model_name.pkl",
  12. "loss": Loss(None),
  13. "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
  14. "vocab_size": 20,
  15. "word_emb_dim": 100,
  16. "rnn_hidden_units": 100,
  17. "num_classes": 3
  18. }
  19. trainer = SeqLabelTrainer()
  20. train_data = [
  21. [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
  22. [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
  23. [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
  24. [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]],
  25. [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]],
  26. [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]],
  27. ]
  28. dev_data = train_data
  29. model = SeqLabeling(args)
  30. trainer.train(network=model, train_data=train_data, dev_data=dev_data)

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。