You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_trainer.py 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import os
  2. import unittest
  3. from fastNLP.core.dataset import DataSet
  4. from fastNLP.core.field import TextField, LabelField
  5. from fastNLP.core.instance import Instance
  6. from fastNLP.core.loss import Loss
  7. from fastNLP.core.metrics import SeqLabelEvaluator
  8. from fastNLP.core.optimizer import Optimizer
  9. from fastNLP.core.trainer import Trainer
  10. from fastNLP.models.sequence_modeling import SeqLabeling
  11. class TestTrainer(unittest.TestCase):
  12. def test_case_1(self):
  13. args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
  14. "save_best_dev": True, "model_name": "default_model_name.pkl",
  15. "loss": Loss("cross_entropy"),
  16. "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
  17. "vocab_size": 10,
  18. "word_emb_dim": 100,
  19. "rnn_hidden_units": 100,
  20. "num_classes": 5,
  21. "evaluator": SeqLabelEvaluator()
  22. }
  23. trainer = Trainer(**args)
  24. train_data = [
  25. [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
  26. [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
  27. [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
  28. [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
  29. [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
  30. [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
  31. ]
  32. vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
  33. label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}
  34. data_set = DataSet()
  35. for example in train_data:
  36. text, label = example[0], example[1]
  37. x = TextField(text, False)
  38. x_len = LabelField(len(text), is_target=False)
  39. y = TextField(label, is_target=False)
  40. ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
  41. data_set.append(ins)
  42. data_set.index_field("word_seq", vocab)
  43. data_set.index_field("truth", label_vocab)
  44. model = SeqLabeling(args)
  45. trainer.train(network=model, train_data=data_set, dev_data=data_set)
  46. # If this can run, everything is OK.
  47. os.system("rm -rf save")
  48. print("pickle path deleted")