You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_static_embedding.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import unittest
  2. from fastNLP.embeddings import StaticEmbedding
  3. from fastNLP import Vocabulary
  4. import torch
  5. import os
  6. class TestRandomSameEntry(unittest.TestCase):
  7. def test_same_vector(self):
  8. vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
  9. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
  10. words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
  11. words = embed(words)
  12. embed_0 = words[0, 0]
  13. for i in range(1, 3):
  14. assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
  15. embed_0 = words[0, 3]
  16. for i in range(3, 5):
  17. assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
  18. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  19. def test_same_vector2(self):
  20. vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
  21. embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.6B.100d.txt',
  22. lower=True)
  23. words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
  24. words = embed(words)
  25. embed_0 = words[0, 0]
  26. for i in range(1, 3):
  27. assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
  28. embed_0 = words[0, 3]
  29. for i in range(3, 5):
  30. assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
  31. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  32. def test_same_vector3(self):
  33. # 验证lower
  34. word_lst = ["The", "the"]
  35. no_create_word_lst = ['of', 'Of', 'With', 'with']
  36. vocab = Vocabulary().add_word_lst(word_lst)
  37. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  38. embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
  39. lower=True)
  40. words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]])
  41. words = embed(words)
  42. lowered_word_lst = [word.lower() for word in word_lst]
  43. lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
  44. lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
  45. lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
  46. lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
  47. lower=False)
  48. lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]])
  49. lowered_words = lowered_embed(lowered_words)
  50. all_words = word_lst + no_create_word_lst
  51. for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])):
  52. with self.subTest(idx=idx, word=all_words[idx]):
  53. assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
  54. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  55. def test_same_vector4(self):
  56. # 验证在有min_freq下的lower
  57. word_lst = ["The", "the", "the", "The", "a", "A"]
  58. no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
  59. all_words = word_lst[:-2] + no_create_word_lst[:-2]
  60. vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
  61. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  62. embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
  63. lower=True)
  64. words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
  65. words = embed(words)
  66. lowered_word_lst = [word.lower() for word in word_lst]
  67. lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
  68. lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
  69. lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
  70. lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
  71. lower=False)
  72. lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
  73. lowered_words = lowered_embed(lowered_words)
  74. for idx in range(len(all_words)):
  75. word_i, word_j = words[0, idx], lowered_words[0, idx]
  76. with self.subTest(idx=idx, word=all_words[idx]):
  77. assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
  78. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  79. def test_same_vector5(self):
  80. # 检查通过使用min_freq后的word是否内容一致
  81. word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"]
  82. no_create_word_lst = ['of', "of", "she", "she", 'With', 'with']
  83. all_words = word_lst[:-2] + no_create_word_lst[:-2]
  84. vocab = Vocabulary().add_word_lst(word_lst)
  85. vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  86. embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
  87. lower=False, min_freq=2)
  88. words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
  89. words = embed(words)
  90. min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
  91. min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
  92. min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
  93. lower=False)
  94. min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
  95. min_freq_words = min_freq_embed(min_freq_words)
  96. for idx in range(len(all_words)):
  97. word_i, word_j = words[0, idx], min_freq_words[0, idx]
  98. with self.subTest(idx=idx, word=all_words[idx]):
  99. assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size)