import os
import unittest

import torch

from fastNLP import DataSet, Vocabulary
from fastNLP.embeddings.transformers_embedding import (
    TransformersEmbedding,
    TransformersWordPieceEncoder,
)


class TransformersEmbeddingTest(unittest.TestCase):
    """Shape check for ``TransformersEmbedding`` built on an ELECTRA checkpoint."""

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_transformers_embedding_1(self):
        # Imported lazily so that merely importing this module does not
        # require the `transformers` package.
        from transformers import ElectraModel, ElectraTokenizer

        weight_path = "google/electra-small-generator"
        sentence = "this is a test . [SEP] NotInRoberta"
        vocab = Vocabulary().add_word_lst(sentence.split())

        electra = ElectraModel.from_pretrained(weight_path)
        electra_tokenizer = ElectraTokenizer.from_pretrained(weight_path)
        embedding = TransformersEmbedding(
            vocab, electra, electra_tokenizer, word_dropout=0.1
        )

        # A single sequence of four vocabulary indices.
        word_ids = torch.LongTensor([[2, 3, 4, 1]])
        output = embedding(word_ids)

        # Expect one vector of the model's hidden size per input position.
        expected_size = (1, 4, electra.config.hidden_size)
        self.assertEqual(output.size(), expected_size)


class TransformersWordPieceEncoderTest(unittest.TestCase):
    """Checks dataset indexing and output shape of ``TransformersWordPieceEncoder``."""

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_transformers_embedding_1(self):
        # Imported lazily so that merely importing this module does not
        # require the `transformers` package.
        from transformers import ElectraModel, ElectraTokenizer

        weight_path = "google/electra-small-generator"
        electra = ElectraModel.from_pretrained(weight_path)
        electra_tokenizer = ElectraTokenizer.from_pretrained(weight_path)
        wordpiece_encoder = TransformersWordPieceEncoder(electra, electra_tokenizer)

        # Indexing the 'words' field is expected to add a 'word_pieces' field.
        tokens = "this is a test . [SEP]".split()
        dataset = DataSet({'words': [tokens]})
        wordpiece_encoder.index_datasets(dataset, field_name='words')
        self.assertTrue(dataset.has_field('word_pieces'))

        # Encode a single sequence of four ids and verify the output shape.
        piece_ids = torch.LongTensor([[1, 2, 3, 4]])
        output = wordpiece_encoder(piece_ids)
        expected_size = (1, 4, electra.config.hidden_size)
        self.assertEqual(output.size(), expected_size)