You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_embed_loader.py 1.1 kB

Lines 1–33
  1. import unittest
  2. import os
  3. import torch
  4. from fastNLP.loader.embed_loader import EmbedLoader
  5. from fastNLP.core.vocabulary import Vocabulary
  6. class TestEmbedLoader(unittest.TestCase):
  7. glove_path = './test/data_for_tests/glove.6B.50d_test.txt'
  8. pkl_path = './save'
  9. raw_texts = ["i am a cat",
  10. "this is a test of new batch",
  11. "ha ha",
  12. "I am a good boy .",
  13. "This is the most beautiful girl ."
  14. ]
  15. texts = [text.strip().split() for text in raw_texts]
  16. vocab = Vocabulary()
  17. vocab.update(texts)
  18. def test1(self):
  19. emb, _ = EmbedLoader.load_embedding(50, self.glove_path, 'glove', self.vocab, self.pkl_path)
  20. self.assertTrue(emb.shape[0] == (len(self.vocab)))
  21. self.assertTrue(emb.shape[1] == 50)
  22. os.remove(self.pkl_path)
  23. def test2(self):
  24. try:
  25. _ = EmbedLoader.load_embedding(100, self.glove_path, 'glove', self.vocab, self.pkl_path)
  26. self.fail(msg="load dismatch embedding")
  27. except ValueError:
  28. pass