You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_embed_loader.py 2.1 kB

7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import unittest
  2. import numpy as np
  3. from fastNLP import Vocabulary
  4. from fastNLP.io import EmbedLoader
  5. import os
  6. from fastNLP.io.dataset_loader import SSTLoader
  7. from fastNLP.core.const import Const as C
  8. class TestEmbedLoader(unittest.TestCase):
  9. def test_load_with_vocab(self):
  10. vocab = Vocabulary()
  11. glove = "../data_for_tests/glove.6B.50d_test.txt"
  12. word2vec = "../data_for_tests/word2vec_test.txt"
  13. vocab.add_word('the')
  14. vocab.add_word('none')
  15. g_m = EmbedLoader.load_with_vocab(glove, vocab)
  16. self.assertEqual(g_m.shape, (4, 50))
  17. w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
  18. self.assertEqual(w_m.shape, (4, 50))
  19. self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4)
  20. def test_load_without_vocab(self):
  21. words = ['the', 'of', 'in', 'a', 'to', 'and']
  22. glove = "../data_for_tests/glove.6B.50d_test.txt"
  23. word2vec = "../data_for_tests/word2vec_test.txt"
  24. g_m, vocab = EmbedLoader.load_without_vocab(glove)
  25. self.assertEqual(g_m.shape, (8, 50))
  26. for word in words:
  27. self.assertIn(word, vocab)
  28. w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
  29. self.assertEqual(w_m.shape, (8, 50))
  30. self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8)
  31. for word in words:
  32. self.assertIn(word, vocab)
  33. # no unk
  34. w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None)
  35. self.assertEqual(w_m.shape, (7, 50))
  36. self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
  37. for word in words:
  38. self.assertIn(word, vocab)
  39. def test_read_all_glove(self):
  40. pass
  41. # 这是可以运行的,但是总数少于行数,应该是由于glove有重复的word
  42. # path = '/where/to/read/full/glove'
  43. # init_embed, vocab = EmbedLoader.load_without_vocab(path, error='strict')
  44. # print(init_embed.shape)
  45. # print(init_embed.mean())
  46. # print(np.isnan(init_embed).sum())
  47. # print(len(vocab))