You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_bert_embedding.py 475 B

1234567891011121314
  1. import unittest
  2. from fastNLP import Vocabulary
  3. from fastNLP.embeddings import BertEmbedding
  4. import torch
  5. import os
  6. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  7. class TestDownload(unittest.TestCase):
  8. def test_download(self):
  9. # import os
  10. vocab = Vocabulary().add_word_lst("This is a test .".split())
  11. embed = BertEmbedding(vocab, model_dir_or_name='en')
  12. words = torch.LongTensor([[0, 1, 2]])
  13. print(embed(words).size())