test_static_embedding.py

import os
import unittest

import torch

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
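
# These tests exercise StaticEmbedding's loading options: vector
# normalization, word dropout, lower-casing, and min_freq filtering.
# The file-based tests assume the small fixture shipped under
# test/data_for_tests/embedding/small_static_embedding/; the
# 'en-glove-6B-100d' tests fetch pretrained GloVe weights and are
# skipped on Travis CI.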


class TestLoad(unittest.TestCase):
    def test_norm1(self):
        # Only vectors found in the pretrained file should be normalized.
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_norm_found_vector=True)
        # Index 2 ('the') appears in the file, so its norm is 1;
        # index 4 ('notinfile') does not, so its norm is left untouched.
        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
        self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)

    def test_norm2(self):
        # Normalize every vector, whether or not it is found in the file.
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                normalize=True)
        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
        self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)

    def test_dropword(self):
        # Smoke-test the forward pass with word_dropout enabled.
        vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
        for i in range(10):
            length = torch.randint(1, 50, (1,)).item()
            batch = torch.randint(1, 4, (1,)).item()
            words = torch.randint(1, 200, (batch, length)).long()
            embed(words)
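
    # Note: word_dropout randomly replaces input tokens with the
    # unknown-word index during training, so the loop above only verifies
    # that random batches pass through embed() without error; it makes no
    # assertion about which tokens were dropped.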

    def test_only_use_pretrain_word(self):
        def check_word_unk(words, vocab, embed):
            # Every word here should have been mapped to the unk vector (index 1).
            for word in words:
                self.assertListEqual(embed(torch.LongTensor([vocab.to_index(word)])).tolist()[0],
                                     embed(torch.LongTensor([1])).tolist()[0])

        def check_vector_equal(words, vocab, embed, embed_dict, lower=False):
            # Every word here should match its pretrained vector exactly.
            for word in words:
                index = vocab.to_index(word)
                v1 = embed(torch.LongTensor([index])).tolist()[0]
                if lower:
                    word = word.lower()
                v2 = embed_dict[word]
                for v1i, v2i in zip(v1, v2):
                    self.assertAlmostEqual(v1i, v2i, places=4)

        embed_dict = read_static_embed('test/data_for_tests/embedding/small_static_embedding/'
                                       'glove.6B.50d_test.txt')

        # With only_use_pretrain_word=True, only words from the pretrained file get vectors.
        vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
        vocab.add_word('of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True)
        # 'notinfile' is absent from the file and should fall back to unk.
        check_vector_equal(['the', 'a', 'of'], vocab, embed, embed_dict)
        check_word_unk(['notinfile'], vocab, embed)

        # Behaviour with mixed case.
        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile'])
        vocab.add_word('Of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True)
        check_word_unk(['The', 'Of', 'notinfile'], vocab, embed)  # these words should not be found
        check_vector_equal(['a'], vocab, embed, embed_dict)

        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True, lower=True)
        check_vector_equal(['The', 'Of', 'a'], vocab, embed, embed_dict, lower=True)
        check_word_unk(['notinfile'], vocab, embed)

        # Interaction with min_freq.
        vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A', 'notinfile2', 'notinfile2'])
        vocab.add_word('Of', no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
                                                         'glove.6B.50d_test.txt',
                                only_use_pretrain_word=True, lower=True, min_freq=2, only_train_min_freq=True)
        check_vector_equal(['Of', 'a'], vocab, embed, embed_dict, lower=True)
        check_word_unk(['notinfile1', 'The', 'notinfile2'], vocab, embed)


def read_static_embed(fp):
    """
    :param str fp: path to the embedding file
    :return: dict mapping each word to its vector (a list of floats)
    """
    embed = {}
    with open(fp, 'r') as f:
        for line in f:
            line = line.strip()
            if line:
                parts = line.split()
                vector = list(map(float, parts[1:]))
                word = parts[0]
                embed[word] = vector
    return embed
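
# The expected file format is the plain-text GloVe layout: one token per
# line followed by its vector components, e.g.
#   the 0.418 0.24968 -0.41242 ...
# (the values shown are illustrative, not taken from the test fixture).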


class TestRandomSameEntry(unittest.TestCase):
    def test_same_vector(self):
        # With lower=True, case variants of a word should share one vector.
        vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
        words = embed(words)
        embed_0 = words[0, 0]
        for i in range(1, 3):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
        embed_0 = words[0, 3]
        for i in range(3, 5):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_same_vector2(self):
        # Same check as above, but with pretrained GloVe vectors.
        vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
        words = embed(words)
        embed_0 = words[0, 0]
        for i in range(1, 3):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
        embed_0 = words[0, 3]
        for i in range(3, 5):
            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_same_vector3(self):
        # Verify lower=True: an embedding built over the original vocabulary
        # should match one built over an explicitly lower-cased vocabulary.
        word_lst = ["The", "the"]
        no_create_word_lst = ['of', 'Of', 'With', 'with']
        vocab = Vocabulary().add_word_lst(word_lst)
        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in word_lst + no_create_word_lst]])
        words = embed(words)

        lowered_word_lst = [word.lower() for word in word_lst]
        lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
        lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
        lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
                                        lower=False)
        lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst + lowered_no_create_word_lst]])
        lowered_words = lowered_embed(lowered_words)

        all_words = word_lst + no_create_word_lst
        for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])):
            with self.subTest(idx=idx, word=all_words[idx]):
                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_same_vector4(self):
        # Verify lower=True combined with a min_freq threshold on the vocabulary.
        word_lst = ["The", "the", "the", "The", "a", "A"]
        no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
        all_words = word_lst[:-2] + no_create_word_lst[:-2]
        vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                lower=True)
        words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
        words = embed(words)

        lowered_word_lst = [word.lower() for word in word_lst]
        lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
        lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
        lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
        lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
                                        lower=False)
        lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
        lowered_words = lowered_embed(lowered_words)

        for idx in range(len(all_words)):
            word_i, word_j = words[0, idx], lowered_words[0, idx]
            with self.subTest(idx=idx, word=all_words[idx]):
                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_same_vector5(self):
        # Check that passing min_freq to StaticEmbedding gives the same vectors
        # as building the Vocabulary with min_freq directly.
        word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"]
        no_create_word_lst = ['of', "of", "she", "she", 'With', 'with']
        all_words = word_lst[:-2] + no_create_word_lst[:-2]
        vocab = Vocabulary().add_word_lst(word_lst)
        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
        embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
                                lower=False, min_freq=2)
        words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
        words = embed(words)

        min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
        min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
        min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d',
                                         lower=False)
        # lower-casing is not involved here, so index the words as-is.
        min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word) for word in all_words]])
        min_freq_words = min_freq_embed(min_freq_words)

        for idx in range(len(all_words)):
            word_i, word_j = words[0, idx], min_freq_words[0, idx]
            with self.subTest(idx=idx, word=all_words[idx]):
                assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size)
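

# Allow running this file directly; `python test_static_embedding.py` then
# executes the same tests that the project's test runner would discover.
if __name__ == '__main__':
    unittest.main()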