You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_elmo_embedding.py 1.3 kB

123456789101112131415161718192021222324252627282930313233343536
  1. import unittest
  2. from fastNLP import Vocabulary
  3. from fastNLP.embeddings import ElmoEmbedding
  4. import torch
  5. import os
  6. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  7. class TestDownload(unittest.TestCase):
  8. def test_download_small(self):
  9. # import os
  10. vocab = Vocabulary().add_word_lst("This is a test .".split())
  11. elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en-small')
  12. words = torch.LongTensor([[0, 1, 2]])
  13. print(elmo_embed(words).size())
  14. # 首先保证所有权重可以加载;上传权重;验证可以下载
  15. class TestRunElmo(unittest.TestCase):
  16. def test_elmo_embedding(self):
  17. vocab = Vocabulary().add_word_lst("This is a test .".split())
  18. elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', layers='0,1')
  19. words = torch.LongTensor([[0, 1, 2]])
  20. hidden = elmo_embed(words)
  21. print(hidden.size())
  22. def test_elmo_embedding_layer_assertion(self):
  23. vocab = Vocabulary().add_word_lst("This is a test .".split())
  24. try:
  25. elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo',
  26. layers='0,1,2')
  27. except AssertionError as e:
  28. print(e)