
test_bert_embedding.py 6.3 kB

import os
import unittest

import torch

from fastNLP import DataSet, Vocabulary
from fastNLP.embeddings import BertEmbedding, BertWordPieceEncoder
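
# The suite below covers three things: loading and running a pretrained model
# (skipped on CI), BertEmbedding's vocabulary/truncation options against the
# small local fixture in test/data_for_tests/embedding/small_bert, and
# agreement between BertEmbedding and BertWordPieceEncoder.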

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
    def test_download(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = BertEmbedding(vocab, model_dir_or_name='en')
        words = torch.LongTensor([[2, 3, 4, 0]])
        print(embed(words).size())

        for pool_method in ['first', 'last', 'max', 'avg']:
            for include_cls_sep in [True, False]:
                embed = BertEmbedding(vocab, model_dir_or_name='en', pool_method=pool_method,
                                      include_cls_sep=include_cls_sep)
                print(embed(words).size())
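
    # Per the fastNLP BertEmbedding docstring (an assumption worth verifying
    # against your installed version): word_dropout randomly replaces input
    # words with UNK as a regularizer, while dropout is applied to the output
    # embeddings. The loop below only smoke-tests the forward pass under both.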

    def test_word_drop(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
        for _ in range(10):
            words = torch.LongTensor([[2, 3, 4, 0]])
            print(embed(words).size())


class TestBertEmbedding(unittest.TestCase):
    def test_bert_embedding_1(self):
        vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                              word_dropout=0.1)
        requires_grad = embed.requires_grad
        embed.requires_grad = not requires_grad
        embed.train()
        words = torch.LongTensor([[2, 3, 4, 0]])
        result = embed(words)
        self.assertEqual(result.size(), (1, 4, 16))

        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                              word_dropout=0.1, only_use_pretrain_bpe=True)
        embed.eval()
        words = torch.LongTensor([[2, 3, 4, 0]])
        result = embed(words)
        self.assertEqual(result.size(), (1, 4, 16))

        # Over-long input is truncated automatically instead of raising an error.
        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                              word_dropout=0.1, only_use_pretrain_bpe=True, auto_truncate=True)
        words = torch.LongTensor([[2, 3, 4, 1] * 10,
                                  [2, 3] + [0] * 38])
        result = embed(words)
        self.assertEqual(result.size(), (2, 40, 16))
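
    # The four embeddings built below cover the 2x2 grid of
    # only_use_pretrain_bpe x truncate_embed. The arithmetic next to each
    # assert spells out which word pieces are kept, dropped (e.g. NotInBERT
    # under only_use_pretrain_bpe=True), or added (e.g. ##a) in each case.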

    def test_bert_embedding_2(self):
        # Test that only_use_pretrain_bpe and truncate_embed work as expected.
        with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f:
            num_word = len(f.readlines())
        Embedding = BertEmbedding
        vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split())

        embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                           only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1)
        embed_bpe_vocab_size = len(vocab) - 1 + 2  # excludes NotInBERT; adds ##a and [CLS]
        self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab))  # 'tokenzier' is the attribute's actual spelling in fastNLP

        embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                           only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1)
        embed_bpe_vocab_size = num_word  # NotInBERT is excluded
        self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab))

        embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                           only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1)
        embed_bpe_vocab_size = len(vocab) + 2  # adds ##a and [CLS]
        self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab))

        embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                           only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1)
        embed_bpe_vocab_size = num_word + 1  # adds ##a
        self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab))

        # The following tensors should be equal under all the settings above.
        embed1.eval()
        embed2.eval()
        embed3.eval()
        embed4.eval()
        tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]])
        t1 = embed1(tensor)
        t2 = embed2(tensor)
        t3 = embed3(tensor)
        t4 = embed4(tensor)
        self.assertEqual((t1 - t2).sum(), 0)
        self.assertEqual((t1 - t3).sum(), 0)
        self.assertEqual((t1 - t4).sum(), 0)
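
# BertWordPieceEncoder consumes word-piece ids directly (no word-level
# Vocabulary), while BertEmbedding maps a word-level Vocabulary onto word
# pieces and pools them back to one vector per word. The second test below
# checks that the two agree when BertEmbedding uses pool_method='first' with
# include_cls_sep=True.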


class TestBertWordPieceEncoder(unittest.TestCase):
    def test_bert_word_piece_encoder(self):
        embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert',
                                     word_dropout=0.1)
        ds = DataSet({'words': ["this is a test . [SEP]".split()]})
        embed.index_datasets(ds, field_name='words')
        self.assertTrue(ds.has_field('word_pieces'))
        result = embed(torch.LongTensor([[1, 2, 3, 4]]))

    def test_bert_embed_eq_bert_piece_encoder(self):
        ds = DataSet({'words': ["this is a texta model vocab".split(), 'this is'.split()]})
        encoder = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert')
        encoder.eval()
        encoder.index_datasets(ds, field_name='words')
        word_pieces = torch.LongTensor(ds['word_pieces'].get([0, 1]))
        word_pieces_res = encoder(word_pieces)

        vocab = Vocabulary()
        vocab.from_dataset(ds, field_name='words')
        vocab.index_dataset(ds, field_name='words', new_field_name='words')
        ds.set_input('words')
        words = torch.LongTensor(ds['words'].get([0, 1]))
        embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                              pool_method='first', include_cls_sep=True, pooled_cls=False)
        embed.eval()
        words_res = embed(words)

        # Check that the word-piece handling works correctly: 'texta' splits
        # into 'text' + '##a', so piece positions after it shift by one in the
        # first sentence.
        self.assertEqual((word_pieces_res[0, :5] - words_res[0, :5]).sum(), 0)
        self.assertEqual((word_pieces_res[0, 6:] - words_res[0, 5:]).sum(), 0)
        self.assertEqual((word_pieces_res[1, :3] - words_res[1, :3]).sum(), 0)
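
For reference, a minimal standalone sketch of the usage pattern these tests
exercise, assuming the same small_bert fixture is on disk; the vocabulary and
sentence are purely illustrative:

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

# Build a word-level vocabulary from whitespace-tokenized text.
vocab = Vocabulary().add_word_lst("this is a test .".split())

# Point at the small test fixture; for real work pass a pretrained
# model name such as 'en' instead.
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
                      pool_method='first', include_cls_sep=False)

# Index a sentence through the vocabulary and run the forward pass.
words = torch.LongTensor([[vocab.to_index(w) for w in "this is a test .".split()]])
print(embed(words).size())  # torch.Size([1, 5, 16]) for the 16-dim test model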