You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_elmo_embedding.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. import unittest
  2. from fastNLP import Vocabulary
  3. from fastNLP.embeddings import ElmoEmbedding
  4. import torch
  5. import os
  6. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  7. class TestDownload(unittest.TestCase):
  8. def test_download_small(self):
  9. # import os
  10. vocab = Vocabulary().add_word_lst("This is a test .".split())
  11. elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en-small')
  12. words = torch.LongTensor([[0, 1, 2]])
  13. print(elmo_embed(words).size())
  14. # 首先保证所有权重可以加载;上传权重;验证可以下载
  15. class TestRunElmo(unittest.TestCase):
  16. def test_elmo_embedding(self):
  17. vocab = Vocabulary().add_word_lst("This is a test .".split())
  18. elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', layers='0,1')
  19. words = torch.LongTensor([[0, 1, 2]])
  20. hidden = elmo_embed(words)
  21. print(hidden.size())
  22. self.assertEqual(hidden.size(), (1, 3, elmo_embed.embedding_dim))
  23. def test_elmo_embedding_layer_assertion(self):
  24. vocab = Vocabulary().add_word_lst("This is a test .".split())
  25. try:
  26. elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo',
  27. layers='0,1,2')
  28. except AssertionError as e:
  29. print(e)