You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_static_embedding.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import unittest
  2. from fastNLP.embeddings import StaticEmbedding
  3. from fastNLP import Vocabulary
  4. import torch
  5. import os
  6. class TestLoad(unittest.TestCase):
  7. def test_norm1(self):
  8. # 测试只对可以找到的norm
  9. vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
  10. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  11. 'glove.6B.50d_test.txt',
  12. only_norm_found_vector=True)
  13. self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
  14. self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)
  15. def test_norm2(self):
  16. # 测试对所有都norm
  17. vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
  18. embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/'
  19. 'glove.6B.50d_test.txt',
  20. normalize=True)
  21. self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
  22. self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
  23. def test_dropword(self):
  24. # 测试是否可以通过drop word
  25. vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
  26. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
  27. for i in range(10):
  28. length = torch.randint(1, 50, (1,)).item()
  29. batch = torch.randint(1, 4, (1,)).item()
  30. words = torch.randint(1, 200, (batch, length)).long()
  31. embed(words)
  32. class TestRandomSameEntry(unittest.TestCase):
  33. def test_same_vector(self):
  34. vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
  35. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
  36. words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
  37. words = embed(words)
  38. embed_0 = words[0, 0]
  39. for i in range(1, 3):
  40. assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
  41. embed_0 = words[0, 3]
  42. for i in range(3, 5):
  43. assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
  44. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  45. def test_same_vector2(self):
  46. vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
  47. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
  48. lower=True)
  49. words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
  50. words = embed(words)
  51. embed_0 = words[0, 0]
  52. for i in range(1, 3):
  53. assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
  54. embed_0 = words[0, 3]
  55. for i in range(3, 5):
  56. assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
  57. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  58. def test_same_vector3(self):
  59. # 验证lower
  60. word_lst = ["The", "the"]
  61. no_create_word_lst = ['of', 'Of', 'With', 'with']
  62. vocab = Vocabulary().add_word_lst(word_lst)
  63. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  64. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
  65. lower=True)
  66. words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]])
  67. words = embed(words)
  68. lowered_word_lst = [word.lower() for word in word_lst]
  69. lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
  70. lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
  71. lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
  72. lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
  73. lower=False)
  74. lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]])
  75. lowered_words = lowered_embed(lowered_words)
  76. all_words = word_lst + no_create_word_lst
  77. for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])):
  78. with self.subTest(idx=idx, word=all_words[idx]):
  79. assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
  80. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  81. def test_same_vector4(self):
  82. # 验证在有min_freq下的lower
  83. word_lst = ["The", "the", "the", "The", "a", "A"]
  84. no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
  85. all_words = word_lst[:-2] + no_create_word_lst[:-2]
  86. vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
  87. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  88. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
  89. lower=True)
  90. words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
  91. words = embed(words)
  92. lowered_word_lst = [word.lower() for word in word_lst]
  93. lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
  94. lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
  95. lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
  96. lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
  97. lower=False)
  98. lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
  99. lowered_words = lowered_embed(lowered_words)
  100. for idx in range(len(all_words)):
  101. word_i, word_j = words[0, idx], lowered_words[0, idx]
  102. with self.subTest(idx=idx, word=all_words[idx]):
  103. assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
  104. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  105. def test_same_vector5(self):
  106. # 检查通过使用min_freq后的word是否内容一致
  107. word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"]
  108. no_create_word_lst = ['of', "of", "she", "she", 'With', 'with']
  109. all_words = word_lst[:-2] + no_create_word_lst[:-2]
  110. vocab = Vocabulary().add_word_lst(word_lst)
  111. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  112. embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
  113. lower=False, min_freq=2)
  114. words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
  115. words = embed(words)
  116. min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
  117. min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  118. min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d',
  119. lower=False)
  120. min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
  121. min_freq_words = min_freq_embed(min_freq_words)
  122. for idx in range(len(all_words)):
  123. word_i, word_j = words[0, idx], min_freq_words[0, idx]
  124. with self.subTest(idx=idx, word=all_words[idx]):
  125. assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size)