You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_embed_loader.py 1.1 kB

12345678910111213141516171819202122232425262728293031
  1. import os
  2. import unittest
  3. from fastNLP.core.vocabulary import Vocabulary
  4. from fastNLP.io.embed_loader import EmbedLoader
  class TestEmbedLoader(unittest.TestCase):
      """Tests for ``EmbedLoader.load_embedding`` using a small GloVe fixture file.

      The fixture path is relative, so these tests assume the working directory
      is the repository root (NOTE(review): confirm against the test runner setup).
      """

      glove_path = './test/data_for_tests/glove.6B.50d_test.txt'
      # Cache file path handed to load_embedding; test1 expects it to be written.
      # NOTE(review): a stale file left here by an earlier run may influence
      # load_embedding if it reads this path as a cache -- confirm.
      pkl_path = './save'
      raw_texts = ["i am a cat",
                   "this is a test of new batch",
                   "ha ha",
                   "I am a good boy .",
                   "This is the most beautiful girl ."
                   ]
      # Whitespace-tokenize each sentence; the vocabulary is built once, when the
      # class body executes, and shared by every test method.
      texts = [text.strip().split() for text in raw_texts]
      vocab = Vocabulary()
      vocab.update(texts)

      def setUp(self):
          """Register removal of the cache file so it happens even if a test fails."""
          self.addCleanup(self._remove_pkl)

      def _remove_pkl(self):
          """Delete ``pkl_path`` if present; a missing file is not an error."""
          try:
              os.remove(self.pkl_path)
          except FileNotFoundError:
              pass

      def test1(self):
          """Loading 50-d vectors gives one row per vocab entry and writes the cache file."""
          emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path)
          self.assertEqual(emb.shape[0], len(self.vocab))
          self.assertEqual(emb.shape[1], 50)
          # Loading is expected to write the cache file at pkl_path; check it
          # explicitly here -- the registered cleanup deletes it afterwards.
          self.assertTrue(os.path.isfile(self.pkl_path))

      def test2(self):
          """Requesting 100-d vectors from the 50-d fixture must raise ValueError."""
          with self.assertRaises(ValueError, msg="load mismatched embedding"):
              EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path)