You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tester.py 2.0 kB

Lines 1–55
  1. import os
  2. import unittest
  3. from fastNLP.core.dataset import DataSet
  4. from fastNLP.core.field import TextField
  5. from fastNLP.core.instance import Instance
  6. from fastNLP.core.tester import SeqLabelTester
  7. from fastNLP.models.sequence_modeling import SeqLabeling
  8. data_name = "pku_training.utf8"
  9. pickle_path = "data_for_tests"
  10. class TestTester(unittest.TestCase):
  11. def test_case_1(self):
  12. model_args = {
  13. "vocab_size": 10,
  14. "word_emb_dim": 100,
  15. "rnn_hidden_units": 100,
  16. "num_classes": 5
  17. }
  18. valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
  19. "save_loss": True, "batch_size": 2, "pickle_path": "./save/",
  20. "use_cuda": False, "print_every_step": 1}
  21. train_data = [
  22. [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
  23. [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
  24. [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
  25. [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
  26. [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
  27. [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
  28. ]
  29. vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
  30. label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}
  31. data_set = DataSet()
  32. for example in train_data:
  33. text, label = example[0], example[1]
  34. x = TextField(text, False)
  35. y = TextField(label, is_target=True)
  36. ins = Instance(word_seq=x, label_seq=y)
  37. data_set.append(ins)
  38. data_set.index_field("word_seq", vocab)
  39. data_set.index_field("label_seq", label_vocab)
  40. model = SeqLabeling(model_args)
  41. tester = SeqLabelTester(**valid_args)
  42. tester.test(network=model, dev_data=data_set)
  43. # If this can run, everything is OK.
  44. os.system("rm -rf save")
  45. print("pickle path deleted")