You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_static_embedding.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. import unittest
  2. from fastNLP.embeddings import StaticEmbedding
  3. from fastNLP import Vocabulary
  4. import torch
  5. import os
  6. class TestLoad(unittest.TestCase):
  7. def test_norm1(self):
  8. # 测试只对可以找到的norm
  9. vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
  10. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  11. 'glove.6B.50d_test.txt',
  12. only_norm_found_vector=True)
  13. self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
  14. self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)
  15. def test_norm2(self):
  16. # 测试对所有都norm
  17. vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
  18. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  19. 'glove.6B.50d_test.txt',
  20. normalize=True)
  21. self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
  22. self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
  23. def test_dropword(self):
  24. # 测试是否可以通过drop word
  25. vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
  26. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
  27. for i in range(10):
  28. length = torch.randint(1, 50, (1,)).item()
  29. batch = torch.randint(1, 4, (1,)).item()
  30. words = torch.randint(1, 200, (batch, length)).long()
  31. embed(words)
  32. def test_only_use_pretrain_word(self):
  33. def check_word_unk(words, vocab, embed):
  34. for word in words:
  35. self.assertListEqual(embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0],
  36. embed(torch.LongTensor([1])).tolist()[0])
  37. def check_vector_equal(words, vocab, embed, embed_dict, lower=False):
  38. for word in words:
  39. index = vocab.to_index(word)
  40. v1 = embed(torch.LongTensor([index])).tolist()[0]
  41. if lower:
  42. word = word.lower()
  43. v2 = embed_dict[word]
  44. for v1i, v2i in zip(v1, v2):
  45. self.assertAlmostEqual(v1i, v2i, places=4)
  46. embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
  47. 'glove.6B.50d_test.txt')
  48. # 测试是否只使用pretrain的word
  49. vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
  50. vocab.add_word('of', no_create_entry=True)
  51. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  52. 'glove.6B.50d_test.txt',
  53. only_use_pretrain_word=True)
  54. # notinfile应该被置为unk
  55. check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict)
  56. check_word_unk(['notinfile'], vocab, embed)
  57. # 测试在大小写情况下的使用
  58. vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile'])
  59. vocab.add_word('Of', no_create_entry=True)
  60. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  61. 'glove.6B.50d_test.txt',
  62. only_use_pretrain_word=True)
  63. check_word_unk(['The', 'Of', 'notinfile'], vocab, embed) # 这些词应该找不到
  64. check_vector_equal(['a'], vocab, embed, embed_dict)
  65. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  66. 'glove.6B.50d_test.txt',
  67. only_use_pretrain_word=True, lower=True)
  68. check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True)
  69. check_word_unk(['notinfile'], vocab, embed)
  70. # 测试min_freq
  71. vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
  72. vocab.add_word('Of', no_create_entry=True)
  73. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  74. 'glove.6B.50d_test.txt',
  75. only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True)
  76. check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True)
  77. check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed)
  78. def test_sequential_index(self):
  79. # 当不存在no_create_entry时,words_to_words应该是顺序的
  80. vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
  81. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  82. 'glove.6B.50d_test.txt')
  83. for index,i in enumerate(embed.words_to_words):
  84. assert index==i
  85. embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
  86. 'glove.6B.50d_test.txt')
  87. for word, index in vocab:
  88. if word in embed_dict:
  89. index = vocab.to_index(word)
  90. v1 = embed(torch.LongTensor([index])).tolist()[0]
  91. v2 = embed_dict[word]
  92. for v1i, v2i in zip(v1, v2):
  93. self.assertAlmostEqual(v1i, v2i, places=4)
  94. def read_static_embed(fp):
  95. """
  96. :param str fp: embedding的路径
  97. :return: {}, key是word, value是vector
  98. """
  99. embed = {}
  100. with open(fp, 'r') as f:
  101. for line in f:
  102. line = line.strip()
  103. if line:
  104. parts = line.split()
  105. vector = list(map(float, parts[1:]))
  106. word = parts[0]
  107. embed[word] = vector
  108. return embed
  109. class TestRandomSameEntry(unittest.TestCase):
  110. def test_same_vector(self):
  111. vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
  112. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
  113. words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
  114. words = embed(words)
  115. embed_0 = words[0, 0]
  116. for i in range(1, 3):
  117. assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
  118. embed_0 = words[0, 3]
  119. for i in range(3, 5):
  120. assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
  121. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  122. def test_same_vector2(self):
  123. vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
  124. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
  125. lower=True)
  126. words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
  127. words = embed(words)
  128. embed_0 = words[0, 0]
  129. for i in range(1, 3):
  130. assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
  131. embed_0 = words[0, 3]
  132. for i in range(3, 5):
  133. assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
  134. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  135. def test_same_vector3(self):
  136. # 验证lower
  137. word_lst = ["The", "the"]
  138. no_create_word_lst = ['of', 'Of', 'With', 'with']
  139. vocab = Vocabulary().add_word_lst(word_lst)
  140. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  141. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
  142. lower=True)
  143. words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]])
  144. words = embed(words)
  145. lowered_word_lst = [word.lower() for word in word_lst]
  146. lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
  147. lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
  148. lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
  149. lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
  150. lower=False)
  151. lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]])
  152. lowered_words = lowered_embed(lowered_words)
  153. all_words = word_lst + no_create_word_lst
  154. for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])):
  155. with self.subTest(idx=idx, word=all_words[idx]):
  156. assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
  157. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  158. def test_same_vector4(self):
  159. # 验证在有min_freq下的lower
  160. word_lst = ["The", "the", "the", "The", "a", "A"]
  161. no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
  162. all_words = word_lst[:-2] + no_create_word_lst[:-2]
  163. vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
  164. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  165. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
  166. lower=True)
  167. words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
  168. words = embed(words)
  169. lowered_word_lst = [word.lower() for word in word_lst]
  170. lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
  171. lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
  172. lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
  173. lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
  174. lower=False)
  175. lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
  176. lowered_words = lowered_embed(lowered_words)
  177. for idx in range(len(all_words)):
  178. word_i, word_j = words[0, idx], lowered_words[0, idx]
  179. with self.subTest(idx=idx, word=all_words[idx]):
  180. assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
  181. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  182. def test_same_vector5(self):
  183. # 检查通过使用min_freq后的word是否内容一致
  184. word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"]
  185. no_create_word_lst = ['of', "of", "she", "she", 'With', 'with']
  186. all_words = word_lst[:-2] + no_create_word_lst[:-2]
  187. vocab = Vocabulary().add_word_lst(word_lst)
  188. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  189. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
  190. lower=False, min_freq=2)
  191. words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
  192. words = embed(words)
  193. min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
  194. min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  195. min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d',
  196. lower=False)
  197. min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
  198. min_freq_words = min_freq_embed(min_freq_words)
  199. for idx in range(len(all_words)):
  200. word_i, word_j = words[0, idx], min_freq_words[0, idx]
  201. with self.subTest(idx=idx, word=all_words[idx]):
  202. assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size)