import unittest
import os

import torch

from fastNLP import Vocabulary
from fastNLP import DataSet
from fastNLP.embeddings import BertEmbedding, BertWordPieceEncoder


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
    def test_download(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = BertEmbedding(vocab, model_dir_or_name='en')
        words = torch.LongTensor([[2, 3, 4, 0]])
        print(embed(words).size())

        for pool_method in ['first', 'last', 'max', 'avg']:
            for include_cls_sep in [True, False]:
                embed = BertEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method,
                                      include_cls_sep=include_cls_sep)
                print(embed(words).size())

    def test_word_drop(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
        for i in range(10):
            words = torch.LongTensor([[2, 3, 4, 0]])
            print(embed(words).size())


class TestBertEmbedding(unittest.TestCase):
    def test_bert_embedding_1(self):
        vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                              word_dropout=0.1)
        requires_grad = embed.requires_grad
        embed.requires_grad = not requires_grad
        embed.train()
        words = torch.LongTensor([[2, 3, 4, 0]])
        result = embed(words)
        self.assertEqual(result.size(), (1, 4, 16))

        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                              word_dropout=0.1, only_use_pretrain_bpe=True)
        embed.eval()
        words = torch.LongTensor([[2, 3, 4, 0]])
        result = embed(words)
        self.assertEqual(result.size(), (1, 4, 16))

        # over-long sequences should be truncated automatically instead of raising an error
        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                              word_dropout=0.1, only_use_pretrain_bpe=True, auto_truncate=True)
        words = torch.LongTensor([[2, 3, 4, 1] * 10, [2, 3] + [0] * 38])
        result = embed(words)
        self.assertEqual(result.size(), (2, 40, 16))

    def test_bert_embedding_2(self):
        # check that only_use_pretrain_bpe and truncate_embed work as expected
        with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f:
            num_word = len(f.readlines())
        Embedding = BertEmbedding
        vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split())

        embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                           only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
        embed_bpe_vocab_size = len(vocab) - 1 + 2  # NotInBERT excluded; ##a and [CLS] added
        self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))

        embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                           only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
        embed_bpe_vocab_size = num_word  # NotInBERT excluded
        self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))

        embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                           only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
        embed_bpe_vocab_size = len(vocab) + 2  # ##a and [CLS] newly added
        self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))

        embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                           only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
        embed_bpe_vocab_size = num_word + 1  # ##a newly added
        self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))

        # the outputs of all four configurations should be identical
        embed1.eval()
        embed2.eval()
        embed3.eval()
        embed4.eval()
        tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
        t1 = embed1(tensor)
        t2 = embed2(tensor)
        t3 = embed3(tensor)
        t4 = embed4(tensor)

        self.assertEqual((t1 - t2).sum(), 0)
        self.assertEqual((t1 - t3).sum(), 0)
        self.assertEqual((t1 - t4).sum(), 0)


class TestBertWordPieceEncoder(unittest.TestCase):
    def test_bert_word_piece_encoder(self):
        embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert',
                                     word_dropout=0.1)
        ds = DataSet({'words': ["this is a test . [SEP]".split()]})
        embed.index_datasets(ds, field_name='words')
        self.assertTrue(ds.has_field('word_pieces'))
        result = embed(torch.LongTensor([[1, 2, 3, 4]]))

    def test_bert_embed_eq_bert_piece_encoder(self):
        ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]})
        encoder = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert')
        encoder.eval()
        encoder.index_datasets(ds, field_name='words')
        word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
        word_pieces_res = encoder(word_pieces)

        vocab = Vocabulary()
        vocab.from_dataset(ds, field_name='words')
        vocab.index_dataset(ds, field_name='words', new_field_name='words')
        ds.set_input('words')
        words = torch.LongTensor(ds['words'].get([0, 1]))
        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                              pool_method='first', include_cls_sep=True, pooled_cls=False)
        embed.eval()
        words_res = embed(words)

        # check that word-piece indexing and word-level pooling stay aligned
        self.assertEqual((word_pieces_res[0, :5] - words_res[0, :5]).sum(), 0)
        self.assertEqual((word_pieces_res[0, 6:] - words_res[0, 5:]).sum(), 0)
        self.assertEqual((word_pieces_res[1, :3] - words_res[1, :3]).sum(), 0)