You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tester.py 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import os
  2. import unittest
  3. from fastNLP.core.dataset import SeqLabelDataSet
  4. from fastNLP.core.metrics import SeqLabelEvaluator
  5. from fastNLP.core.field import TextField, LabelField
  6. from fastNLP.core.instance import Instance
  7. from fastNLP.core.tester import SeqLabelTester
  8. from fastNLP.models.sequence_modeling import SeqLabeling
  # Module-level fixture names. Neither constant is referenced by the test
  # case visible in this file; they are kept so that any external reader of
  # these names keeps working.
  data_name = "pku_training.utf8"
  pickle_path = "data_for_tests"
  class TestTester(unittest.TestCase):
      """Smoke test that runs ``SeqLabelTester`` on a small in-memory data set."""

      def test_case_1(self):
          """Build six toy instances and run ``SeqLabelTester.test`` on them.

          No assertion is made on the outcome: the test passes as long as the
          run completes without raising. The ``save`` directory (the
          ``pickle_path`` handed to the tester) is removed afterwards, even
          when the run fails.
          """
          # Local import: only the cleanup step of this method needs it.
          import shutil

          model_args = {
              "vocab_size": 10,  # equals the number of entries in `vocab` below
              "word_emb_dim": 100,
              "rnn_hidden_units": 100,
              "num_classes": 5,  # equals the number of entries in `label_vocab` below
          }
          valid_args = {
              "save_output": True,
              "validate_in_training": True,
              "save_dev_input": True,
              "save_loss": True,
              "batch_size": 2,
              "pickle_path": "./save/",
              "use_cuda": False,
              "print_every_step": 1,
              "evaluator": SeqLabelEvaluator(),
          }

          # Each example is a [text tokens, label tokens] pair of equal length.
          train_data = [
              [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
              [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
              [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
              [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']],
              [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']],
              [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']],
          ]
          vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
          label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

          data_set = SeqLabelDataSet()
          for text, label in train_data:
              x = TextField(text, False)
              x_len = LabelField(len(text), is_target=False)
              y = TextField(label, is_target=True)
              ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len)
              data_set.append(ins)
          data_set.index_field("word_seq", vocab)
          data_set.index_field("truth", label_vocab)

          try:
              model = SeqLabeling(model_args)
              tester = SeqLabelTester(**valid_args)
              tester.test(network=model, dev_data=data_set)
              # If this can run, everything is OK.
          finally:
              # Portable replacement for `rm -rf save`; running it in `finally`
              # keeps a failed run from leaving the directory behind.
              shutil.rmtree("save", ignore_errors=True)
              print("pickle path deleted")