
test_static_embedding.py

import unittest
import os

import torch

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding


class TestLoad(unittest.TestCase):
    def test_norm1(self):
        # only vectors found in the pretrained file should be normalized
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_norm_found_vector=True)
        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
        self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)

    def test_norm2(self):
        # all vectors should be normalized
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                normalize=True)
        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)

    def test_dropword(self):
        # check that forward passes run with word_dropout enabled
        vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
        for i in range(10):
            length = torch.randint(1, 50, (1,)).item()
            batch = torch.randint(1, 4, (1,)).item()
            words = torch.randint(1, 200, (batch, length)).long()
            embed(words)

    def test_only_use_pretrain_word(self):
        def check_word_unk(words, vocab, embed):
            for word in words:
                self.assertListEqual(embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0],
                                     embed(torch.LongTensor([1])).tolist()[0])

        def check_vector_equal(words, vocab, embed, embed_dict, lower=False):
            for word in words:
                index = vocab.to_index(word)
                v1 = embed(torch.LongTensor([index])).tolist()[0]
                if lower:
                    word = word.lower()
                v2 = embed_dict[word]
                for v1i, v2i in zip(v1, v2):
                    self.assertAlmostEqual(v1i, v2i, places=4)

        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
                                       'glove.6B.50d_test.txt')

        # check that only pretrained words are used
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        vocab.add_word('of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True)
        # 'notinfile' should be mapped to unk
        check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict)
        check_word_unk(['notinfile'], vocab, embed)

        # check behaviour with mixed-case words
        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile'])
        vocab.add_word('Of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True)
        check_word_unk(['The', 'Of', 'notinfile'], vocab, embed)  # these words should not be found
        check_vector_equal(['a'], vocab, embed, embed_dict)

        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True, lower=True)
        check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True)
        check_word_unk(['notinfile'], vocab, embed)

        # check min_freq
        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
        vocab.add_word('Of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True)
        check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True)
        check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed)

    def test_sequential_index(self):
        # when no word is marked no_create_entry, words_to_words should be the identity mapping
        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt')
        for index, i in enumerate(embed.words_to_words):
            assert index == i

        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
                                       'glove.6B.50d_test.txt')
        for word, index in vocab:
            if word in embed_dict:
                index = vocab.to_index(word)
                v1 = embed(torch.LongTensor([index])).tolist()[0]
                v2 = embed_dict[word]
                for v1i, v2i in zip(v1, v2):
                    self.assertAlmostEqual(v1i, v2i, places=4)

    def test_save_load_static_embed(self):
        static_test_folder = 'static_save_test'
        try:
            # vocabulary containing no_create_entry words
            os.makedirs(static_test_folder, exist_ok=True)
            vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A'])
            vocab.add_word_lst(['notinfile2', 'notinfile2'], no_create_entry=True)
            embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                             'glove.6B.50d_test.txt')
            embed.save(static_test_folder)
            load_embed = StaticEmbedding.load(static_test_folder)
            words = torch.randint(len(vocab), size=(2, 20))
            self.assertEqual((embed(words) - load_embed(words)).sum(), 0)

            # vocabulary without no_create_entry words
            vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A'])
            embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                             'glove.6B.50d_test.txt')
            embed.save(static_test_folder)
            load_embed = StaticEmbedding.load(static_test_folder)
            words = torch.randint(len(vocab), size=(2, 20))
            self.assertEqual((embed(words) - load_embed(words)).sum(), 0)

            # with lower and min_freq
            vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B'])
            embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                             'glove.6B.50d_test.txt', min_freq=2, lower=True)
            embed.save(static_test_folder)
            load_embed = StaticEmbedding.load(static_test_folder)
            words = torch.randint(len(vocab), size=(2, 20))
            self.assertEqual((embed(words) - load_embed(words)).sum(), 0)

            # with randomly initialized embeddings
            vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B'])
            vocab = vocab.add_word_lst(['b'], no_create_entry=True)
            embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=4, min_freq=2, lower=True,
                                    normalize=True)
            embed.weight.data += 0.2  # make the vectors no longer normalized
            embed.save(static_test_folder)
            load_embed = StaticEmbedding.load(static_test_folder)
            words = torch.randint(len(vocab), size=(2, 20))
            self.assertEqual((embed(words) - load_embed(words)).sum(), 0)
        finally:
            if os.path.isdir(static_test_folder):
                import shutil
                shutil.rmtree(static_test_folder)


def read_static_embed(fp):
    """
    :param str fp: path to the embedding file
    :return: dict mapping each word to its vector (a list of floats)
    """
    embed = {}
    with open(fp, 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split()
                vector = list(map(float, parts[1:]))
                word = parts[0]
                embed[word] = vector
    return embed


class TestRandomSameEntry(unittest.TestCase):
    def test_same_vector(self):
        # with lower=True, different casings of the same word should share one vector
        vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
        words = embed(words)
        embed_0 = words[0, 0]
        for i in range(1, 3):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
        embed_0 = words[0, 3]
        for i in range(3, 5):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_same_vector2(self):
        vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
        words = embed(words)
        embed_0 = words[0, 0]
        for i in range(1, 3):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
        embed_0 = words[0, 3]
        for i in range(3, 5):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_same_vector3(self):
        # verify that lower=True matches a pre-lowered vocabulary
        word_lst = ["The", "the"]
        no_create_word_lst = ['of', 'Of', 'With', 'with']
        vocab = Vocabulary().add_word_lst(word_lst)
        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in word_lst + no_create_word_lst]])
        words = embed(words)

        lowered_word_lst = [word.lower() for word in word_lst]
        lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
        lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
        lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
                                        lower=False)
        lowered_words = torch.LongTensor([[lowered_vocab.to_index(word)
                                           for word in lowered_word_lst + lowered_no_create_word_lst]])
        lowered_words = lowered_embed(lowered_words)

        all_words = word_lst + no_create_word_lst
        for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])):
            with self.subTest(idx=idx, word=all_words[idx]):
                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_same_vector4(self):
        # verify lower=True together with min_freq
        word_lst = ["The", "the", "the", "The", "a", "A"]
        no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
        all_words = word_lst[:-2] + no_create_word_lst[:-2]
        vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
        words = embed(words)

        lowered_word_lst = [word.lower() for word in word_lst]
        lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
        lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
        lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
                                        lower=False)
        lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
        lowered_words = lowered_embed(lowered_words)

        for idx in range(len(all_words)):
            word_i, word_j = words[0, idx], lowered_words[0, idx]
            with self.subTest(idx=idx, word=all_words[idx]):
                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_same_vector5(self):
        # check that words surviving min_freq filtering still get the same vectors
        word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"]
        no_create_word_lst = ['of', "of", "she", "she", 'With', 'with']
        all_words = word_lst[:-2] + no_create_word_lst[:-2]
        vocab = Vocabulary().add_word_lst(word_lst)
        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                lower=False, min_freq=2)
        words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
        words = embed(words)

        min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
        min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
        min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d',
                                         lower=False)
        min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
        min_freq_words = min_freq_embed(min_freq_words)

        for idx in range(len(all_words)):
            word_i, word_j = words[0, idx], min_freq_words[0, idx]
            with self.subTest(idx=idx, word=all_words[idx]):
                assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size)
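

# A minimal sketch of a direct entry point for this module (an assumption, not part of
# the original file): it lets the tests run via `python test_static_embedding.py`,
# provided the fastNLP test data paths above exist relative to the working directory.
if __name__ == '__main__':
    unittest.main()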