You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_vocab.py 1.3 kB

1234567891011121314151617181920212223242526272829303132333435
  1. import os
  2. import sys
  3. sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
  4. import unittest
  5. from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX
  6. class TestVocabulary(unittest.TestCase):
  7. def test_vocab(self):
  8. import _pickle as pickle
  9. import os
  10. vocab = Vocabulary()
  11. filename = 'vocab'
  12. vocab.update(filename)
  13. vocab.update([filename, ['a'], [['b']], ['c']])
  14. idx = vocab[filename]
  15. before_pic = (vocab.to_word(idx), vocab[filename])
  16. with open(filename, 'wb') as f:
  17. pickle.dump(vocab, f)
  18. with open(filename, 'rb') as f:
  19. vocab = pickle.load(f)
  20. os.remove(filename)
  21. vocab.build_reverse_vocab()
  22. after_pic = (vocab.to_word(idx), vocab[filename])
  23. TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
  24. TRUE_DICT.update(DEFAULT_WORD_TO_INDEX)
  25. TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}
  26. self.assertEqual(before_pic, after_pic)
  27. self.assertDictEqual(TRUE_DICT, vocab.word2idx)
  28. self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word)
  29. if __name__ == '__main__':
  30. unittest.main()