
test_embed_loader.py

import unittest

import numpy as np

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader


class TestEmbedLoader(unittest.TestCase):
    def test_case(self):
        vocab = Vocabulary()
        vocab.update(["the", "in", "I", "to", "of", "hahaha"])
        embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
        self.assertEqual(tuple(embedding.shape), (len(vocab), 50))

    def test_load_with_vocab(self):
        vocab = Vocabulary()
        glove = "test/data_for_tests/glove.6B.50d_test.txt"
        word2vec = "test/data_for_tests/word2vec_test.txt"
        vocab.add_word('the')
        vocab.add_word('none')
        # two added words plus the vocabulary's special entries give 4 rows
        g_m = EmbedLoader.load_with_vocab(glove, vocab)
        self.assertEqual(g_m.shape, (4, 50))
        # normalize=True gives every row unit L2 norm, so the norms sum to the row count
        w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
        self.assertEqual(w_m.shape, (4, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4)

    def test_load_without_vocab(self):
        words = ['the', 'of', 'in', 'a', 'to', 'and']
        glove = "test/data_for_tests/glove.6B.50d_test.txt"
        word2vec = "test/data_for_tests/word2vec_test.txt"
        g_m, vocab = EmbedLoader.load_without_vocab(glove)
        self.assertEqual(g_m.shape, (8, 50))
        for word in words:
            self.assertIn(word, vocab)
        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
        self.assertEqual(w_m.shape, (8, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8)
        for word in words:
            self.assertIn(word, vocab)
        # no unk: dropping the unknown token leaves one fewer row
        w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None)
        self.assertEqual(w_m.shape, (7, 50))
        self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
        for word in words:
            self.assertIn(word, vocab)


if __name__ == "__main__":
    unittest.main()
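
Outside the test suite, the same EmbedLoader entry points exercised above can be used roughly as follows. This is a minimal sketch mirroring the calls in the tests; the embedding file path is a placeholder, not a file shipped with this repository.

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader

# Build a vocabulary and load an embedding matrix aligned with it (rows follow vocab indices).
vocab = Vocabulary()
vocab.update(["the", "of", "in", "a", "to", "and"])
matrix = EmbedLoader.load_with_vocab("path/to/embeddings.txt", vocab, normalize=True)
print(matrix.shape)  # (len(vocab), embedding_dim)

# Or let the loader build the vocabulary from the embedding file itself.
matrix, vocab = EmbedLoader.load_without_vocab("path/to/embeddings.txt")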