You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_embed_loader.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import unittest
  2. import numpy as np
  3. from fastNLP.core.vocabulary import Vocabulary
  4. from fastNLP.io.embed_loader import EmbedLoader
  5. class TestEmbedLoader(unittest.TestCase):
  6. def test_load_with_vocab(self):
  7. vocab = Vocabulary()
  8. glove = "test/data_for_tests/glove.6B.50d_test.txt"
  9. word2vec = "test/data_for_tests/word2vec_test.txt"
  10. vocab.add_word('the')
  11. vocab.add_word('none')
  12. g_m = EmbedLoader.load_with_vocab(glove, vocab)
  13. self.assertEqual(g_m.shape, (4, 50))
  14. w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
  15. self.assertEqual(w_m.shape, (4, 50))
  16. self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4)
  17. def test_load_without_vocab(self):
  18. words = ['the', 'of', 'in', 'a', 'to', 'and']
  19. glove = "test/data_for_tests/glove.6B.50d_test.txt"
  20. word2vec = "test/data_for_tests/word2vec_test.txt"
  21. g_m, vocab = EmbedLoader.load_without_vocab(glove)
  22. self.assertEqual(g_m.shape, (8, 50))
  23. for word in words:
  24. self.assertIn(word, vocab)
  25. w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
  26. self.assertEqual(w_m.shape, (8, 50))
  27. self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8)
  28. for word in words:
  29. self.assertIn(word, vocab)
  30. # no unk
  31. w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None)
  32. self.assertEqual(w_m.shape, (7, 50))
  33. self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
  34. for word in words:
  35. self.assertIn(word, vocab)