You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_bert_embedding.py 1.6 kB

123456789101112131415161718192021222324252627282930313233343536373839
  1. import unittest
  2. from fastNLP import Vocabulary
  3. from fastNLP.embeddings import BertEmbedding
  4. import torch
  5. import os
  6. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  7. class TestDownload(unittest.TestCase):
  8. def test_download(self):
  9. # import os
  10. vocab = Vocabulary().add_word_lst("This is a test .".split())
  11. embed = BertEmbedding(vocab, model_dir_or_name='en')
  12. words = torch.LongTensor([[2, 3, 4, 0]])
  13. print(embed(words).size())
  14. for pool_method in ['first', 'last', 'max', 'avg']:
  15. for include_cls_sep in [True, False]:
  16. embed = BertEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method,
  17. include_cls_sep=include_cls_sep)
  18. print(embed(words).size())
  19. def test_word_drop(self):
  20. vocab = Vocabulary().add_word_lst("This is a test .".split())
  21. embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
  22. for i in range(10):
  23. words = torch.LongTensor([[2, 3, 4, 0]])
  24. print(embed(words).size())
  25. class TestBertEmbedding(unittest.TestCase):
  26. def test_bert_embedding_1(self):
  27. vocab = Vocabulary().add_word_lst("this is a test . [SEP]".split())
  28. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1)
  29. requires_grad = embed.requires_grad
  30. embed.requires_grad = not requires_grad
  31. embed.train()
  32. words = torch.LongTensor([[2, 3, 4, 0]])
  33. result = embed(words)
  34. self.assertEqual(result.size(), (1, 4, 16))