You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_transformer_embedding.py 1.7 kB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import unittest
  2. import torch
  3. import os
  4. from fastNLP import DataSet, Vocabulary
  5. from fastNLP.embeddings.transformers_embedding import TransformersEmbedding, TransformersWordPieceEncoder
  6. class TransformersEmbeddingTest(unittest.TestCase):
  7. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  8. def test_transformers_embedding_1(self):
  9. from transformers import ElectraModel, ElectraTokenizer
  10. weight_path = "google/electra-small-generator"
  11. vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInRoberta".split())
  12. model = ElectraModel.from_pretrained(weight_path)
  13. tokenizer = ElectraTokenizer.from_pretrained(weight_path)
  14. embed = TransformersEmbedding(vocab, model, tokenizer, word_dropout=0.1)
  15. words = torch.LongTensor([[2, 3, 4, 1]])
  16. result = embed(words)
  17. self.assertEqual(result.size(), (1, 4, model.config.hidden_size))
  18. class TransformersWordPieceEncoderTest(unittest.TestCase):
  19. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  20. def test_transformers_embedding_1(self):
  21. from transformers import ElectraModel, ElectraTokenizer
  22. weight_path = "google/electra-small-generator"
  23. model = ElectraModel.from_pretrained(weight_path)
  24. tokenizer = ElectraTokenizer.from_pretrained(weight_path)
  25. encoder = TransformersWordPieceEncoder(model, tokenizer)
  26. ds = DataSet({'words': ["this is a test . [SEP]".split()]})
  27. encoder.index_datasets(ds, field_name='words')
  28. self.assertTrue(ds.has_field('word_pieces'))
  29. result = encoder(torch.LongTensor([[1,2,3,4]]))
  30. self.assertEqual(result.size(), (1, 4, model.config.hidden_size))