You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_vocabulary.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import unittest
  2. from collections import Counter
  3. from fastNLP.core.vocabulary import Vocabulary
  4. text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
  5. "works", "well", "in", "most", "cases", "scales", "well"]
  6. counter = Counter(text)
  7. class TestAdd(unittest.TestCase):
  8. def test_add(self):
  9. vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
  10. for word in text:
  11. vocab.add(word)
  12. self.assertEqual(vocab.word_count, counter)
  13. def test_add_word(self):
  14. vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
  15. for word in text:
  16. vocab.add_word(word)
  17. self.assertEqual(vocab.word_count, counter)
  18. def test_add_word_lst(self):
  19. vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
  20. vocab.add_word_lst(text)
  21. self.assertEqual(vocab.word_count, counter)
  22. def test_update(self):
  23. vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
  24. vocab.update(text)
  25. self.assertEqual(vocab.word_count, counter)
  26. class TestIndexing(unittest.TestCase):
  27. def test_len(self):
  28. vocab = Vocabulary(need_default=False, max_size=None, min_freq=None)
  29. vocab.update(text)
  30. self.assertEqual(len(vocab), len(counter))
  31. def test_contains(self):
  32. vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
  33. vocab.update(text)
  34. self.assertTrue(text[-1] in vocab)
  35. self.assertFalse("~!@#" in vocab)
  36. self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1]))
  37. self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))
  38. def test_index(self):
  39. vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
  40. vocab.update(text)
  41. res = [vocab[w] for w in set(text)]
  42. self.assertEqual(len(res), len(set(res)))
  43. res = [vocab.to_index(w) for w in set(text)]
  44. self.assertEqual(len(res), len(set(res)))
  45. def test_to_word(self):
  46. vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
  47. vocab.update(text)
  48. self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])