You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_predictor.py 1.7 kB

7 years ago
7 years ago
7 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import os
  2. import unittest
  3. from fastNLP.core.predictor import Predictor
  4. from fastNLP.core.preprocess import save_pickle
  5. from fastNLP.models.sequence_modeling import SeqLabeling
  6. from fastNLP.core.vocabulary import Vocabulary
  7. class TestPredictor(unittest.TestCase):
  8. def test_seq_label(self):
  9. model_args = {
  10. "vocab_size": 10,
  11. "word_emb_dim": 100,
  12. "rnn_hidden_units": 100,
  13. "num_classes": 5
  14. }
  15. infer_data = [
  16. ['a', 'b', 'c', 'd', 'e'],
  17. ['a', '@', 'c', 'd', 'e'],
  18. ['a', 'b', '#', 'd', 'e'],
  19. ['a', 'b', 'c', '?', 'e'],
  20. ['a', 'b', 'c', 'd', '$'],
  21. ['!', 'b', 'c', 'd', 'e']
  22. ]
  23. vocab = Vocabulary()
  24. vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
  25. class_vocab = Vocabulary()
  26. class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4}
  27. os.system("mkdir save")
  28. save_pickle(class_vocab, "./save/", "class2id.pkl")
  29. save_pickle(vocab, "./save/", "word2id.pkl")
  30. model = SeqLabeling(model_args)
  31. predictor = Predictor("./save/", task="seq_label")
  32. results = predictor.predict(network=model, data=infer_data)
  33. self.assertTrue(isinstance(results, list))
  34. self.assertGreater(len(results), 0)
  35. for res in results:
  36. self.assertTrue(isinstance(res, list))
  37. self.assertEqual(len(res), 5)
  38. self.assertTrue(isinstance(res[0], str))
  39. os.system("rm -rf save")
  40. print("pickle path deleted")
  41. class TestPredictor2(unittest.TestCase):
  42. def test_text_classify(self):
  43. # TODO
  44. pass