You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_predictor.py 2.9 kB

7 years ago
7 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
import os
import tempfile
import unittest

from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet
from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.base_loader import BaseLoader
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.models.sequence_modeling import SeqLabeling

  10. class TestPredictor(unittest.TestCase):
  11. def test_seq_label(self):
  12. model_args = {
  13. "vocab_size": 10,
  14. "word_emb_dim": 100,
  15. "rnn_hidden_units": 100,
  16. "num_classes": 5
  17. }
  18. infer_data = [
  19. ['a', 'b', 'c', 'd', 'e'],
  20. ['a', '@', 'c', 'd', 'e'],
  21. ['a', 'b', '#', 'd', 'e'],
  22. ['a', 'b', 'c', '?', 'e'],
  23. ['a', 'b', 'c', 'd', '$'],
  24. ['!', 'b', 'c', 'd', 'e']
  25. ]
  26. vocab = Vocabulary()
  27. vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
  28. class_vocab = Vocabulary()
  29. class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
  30. os.system("mkdir save")
  31. save_pickle(class_vocab, "./save/", "label2id.pkl")
  32. save_pickle(vocab, "./save/", "word2id.pkl")
  33. model = CNNText(model_args)
  34. import fastNLP.core.predictor as pre
  35. predictor = Predictor("./save/", pre.text_classify_post_processor)
  36. # Load infer data
  37. infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load)
  38. infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
  39. results = predictor.predict(network=model, data=infer_data_set)
  40. self.assertTrue(isinstance(results, list))
  41. self.assertGreater(len(results), 0)
  42. self.assertEqual(len(results), len(infer_data))
  43. for res in results:
  44. self.assertTrue(isinstance(res, str))
  45. self.assertTrue(res in class_vocab.word2idx)
  46. del model, predictor, infer_data_set
  47. model = SeqLabeling(model_args)
  48. predictor = Predictor("./save/", pre.seq_label_post_processor)
  49. infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load)
  50. infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx})
  51. results = predictor.predict(network=model, data=infer_data_set)
  52. self.assertTrue(isinstance(results, list))
  53. self.assertEqual(len(results), len(infer_data))
  54. for i in range(len(infer_data)):
  55. res = results[i]
  56. self.assertTrue(isinstance(res, list))
  57. self.assertEqual(len(res), len(infer_data[i]))
  58. os.system("rm -rf save")
  59. print("pickle path deleted")
  60. class TestPredictor2(unittest.TestCase):
  61. def test_text_classify(self):
  62. # TODO
  63. pass