You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_vocabulary.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import unittest
  2. from collections import Counter
  3. from fastNLP.core.vocabulary import Vocabulary
  4. from fastNLP.core.dataset import DataSet
  5. from fastNLP.core.instance import Instance
  6. text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
  7. "works", "well", "in", "most", "cases", "scales", "well"]
  8. counter = Counter(text)
  9. class TestAdd(unittest.TestCase):
  10. def test_add(self):
  11. vocab = Vocabulary(max_size=None, min_freq=None)
  12. for word in text:
  13. vocab.add(word)
  14. self.assertEqual(vocab.word_count, counter)
  15. def test_add_word(self):
  16. vocab = Vocabulary(max_size=None, min_freq=None)
  17. for word in text:
  18. vocab.add_word(word)
  19. self.assertEqual(vocab.word_count, counter)
  20. def test_add_word_lst(self):
  21. vocab = Vocabulary(max_size=None, min_freq=None)
  22. vocab.add_word_lst(text)
  23. self.assertEqual(vocab.word_count, counter)
  24. def test_update(self):
  25. vocab = Vocabulary(max_size=None, min_freq=None)
  26. vocab.update(text)
  27. self.assertEqual(vocab.word_count, counter)
  28. def test_from_dataset(self):
  29. start_char = 65
  30. num_samples = 10
  31. # 0 dim
  32. dataset = DataSet()
  33. for i in range(num_samples):
  34. ins = Instance(char=chr(start_char+i))
  35. dataset.append(ins)
  36. vocab = Vocabulary()
  37. vocab.from_dataset(dataset, field_name='char')
  38. for i in range(num_samples):
  39. self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
  40. vocab.index_dataset(dataset, field_name='char')
  41. # 1 dim
  42. dataset = DataSet()
  43. for i in range(num_samples):
  44. ins = Instance(char=[chr(start_char+i)]*6)
  45. dataset.append(ins)
  46. vocab = Vocabulary()
  47. vocab.from_dataset(dataset, field_name='char')
  48. for i in range(num_samples):
  49. self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
  50. vocab.index_dataset(dataset, field_name='char')
  51. # 2 dim
  52. dataset = DataSet()
  53. for i in range(num_samples):
  54. ins = Instance(char=[[chr(start_char+i) for _ in range(6)] for _ in range(6)])
  55. dataset.append(ins)
  56. vocab = Vocabulary()
  57. vocab.from_dataset(dataset, field_name='char')
  58. for i in range(num_samples):
  59. self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
  60. vocab.index_dataset(dataset, field_name='char')
  61. class TestIndexing(unittest.TestCase):
  62. def test_len(self):
  63. vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
  64. vocab.update(text)
  65. self.assertEqual(len(vocab), len(counter))
  66. def test_contains(self):
  67. vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
  68. vocab.update(text)
  69. self.assertTrue(text[-1] in vocab)
  70. self.assertFalse("~!@#" in vocab)
  71. self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1]))
  72. self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))
  73. def test_index(self):
  74. vocab = Vocabulary(max_size=None, min_freq=None)
  75. vocab.update(text)
  76. res = [vocab[w] for w in set(text)]
  77. self.assertEqual(len(res), len(set(res)))
  78. res = [vocab.to_index(w) for w in set(text)]
  79. self.assertEqual(len(res), len(set(res)))
  80. def test_to_word(self):
  81. vocab = Vocabulary(max_size=None, min_freq=None)
  82. vocab.update(text)
  83. self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
  84. def test_iteration(self):
  85. vocab = Vocabulary()
  86. text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
  87. "works", "well", "in", "most", "cases", "scales", "well"]
  88. vocab.update(text)
  89. text = set(text)
  90. for word in vocab:
  91. self.assertTrue(word in text)
  92. class TestOther(unittest.TestCase):
  93. def test_additional_update(self):
  94. vocab = Vocabulary(max_size=None, min_freq=None)
  95. vocab.update(text)
  96. _ = vocab["well"]
  97. self.assertEqual(vocab.rebuild, False)
  98. vocab.add("hahaha")
  99. self.assertEqual(vocab.rebuild, True)
  100. _ = vocab["hahaha"]
  101. self.assertEqual(vocab.rebuild, False)
  102. self.assertTrue("hahaha" in vocab)
  103. def test_warning(self):
  104. vocab = Vocabulary(max_size=len(set(text)), min_freq=None)
  105. vocab.update(text)
  106. self.assertEqual(vocab.rebuild, True)
  107. print(len(vocab))
  108. self.assertEqual(vocab.rebuild, False)
  109. vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"])
  110. # this will print a warning
  111. self.assertEqual(vocab.rebuild, True)