You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_bert.py 6.9 kB

(lines 1–171)
  1. import unittest
  2. import torch
  3. from fastNLP.core import Vocabulary, Const
  4. from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \
  5. BertForTokenClassification, BertForMultipleChoice, BertForSentenceMatching
  6. from fastNLP.embeddings.bert_embedding import BertEmbedding
  7. class TestBert(unittest.TestCase):
  8. def test_bert_1(self):
  9. vocab = Vocabulary().add_word_lst("this is a test .".split())
  10. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  11. include_cls_sep=True)
  12. model = BertForSequenceClassification(embed, 2)
  13. input_ids = torch.LongTensor([[1, 2, 3], [5, 6, 0]])
  14. pred = model(input_ids)
  15. self.assertTrue(isinstance(pred, dict))
  16. self.assertTrue(Const.OUTPUT in pred)
  17. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
  18. pred = model(input_ids)
  19. self.assertTrue(isinstance(pred, dict))
  20. self.assertTrue(Const.OUTPUT in pred)
  21. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
  22. def test_bert_1_w(self):
  23. vocab = Vocabulary().add_word_lst("this is a test .".split())
  24. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  25. include_cls_sep=False)
  26. with self.assertWarns(Warning):
  27. model = BertForSequenceClassification(embed, 2)
  28. input_ids = torch.LongTensor([[1, 2, 3], [5, 6, 0]])
  29. pred = model.predict(input_ids)
  30. self.assertTrue(isinstance(pred, dict))
  31. self.assertTrue(Const.OUTPUT in pred)
  32. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2,))
  33. def test_bert_2(self):
  34. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  35. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  36. include_cls_sep=True)
  37. model = BertForMultipleChoice(embed, 2)
  38. input_ids = torch.LongTensor([[[2, 6, 7], [1, 6, 5]]])
  39. print(input_ids.size())
  40. pred = model(input_ids)
  41. self.assertTrue(isinstance(pred, dict))
  42. self.assertTrue(Const.OUTPUT in pred)
  43. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
  44. def test_bert_2_w(self):
  45. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  46. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  47. include_cls_sep=False)
  48. with self.assertWarns(Warning):
  49. model = BertForMultipleChoice(embed, 2)
  50. input_ids = torch.LongTensor([[[2, 6, 7], [1, 6, 5]]])
  51. print(input_ids.size())
  52. pred = model.predict(input_ids)
  53. self.assertTrue(isinstance(pred, dict))
  54. self.assertTrue(Const.OUTPUT in pred)
  55. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1,))
  56. def test_bert_3(self):
  57. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  58. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  59. include_cls_sep=False)
  60. model = BertForTokenClassification(embed, 7)
  61. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  62. pred = model(input_ids)
  63. self.assertTrue(isinstance(pred, dict))
  64. self.assertTrue(Const.OUTPUT in pred)
  65. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
  66. def test_bert_3_w(self):
  67. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  68. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  69. include_cls_sep=True)
  70. with self.assertWarns(Warning):
  71. model = BertForTokenClassification(embed, 7)
  72. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  73. pred = model.predict(input_ids)
  74. self.assertTrue(isinstance(pred, dict))
  75. self.assertTrue(Const.OUTPUT in pred)
  76. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3))
  77. def test_bert_4(self):
  78. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  79. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  80. include_cls_sep=False)
  81. model = BertForQuestionAnswering(embed)
  82. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  83. pred = model(input_ids)
  84. self.assertTrue(isinstance(pred, dict))
  85. self.assertTrue('pred_start' in pred)
  86. self.assertTrue('pred_end' in pred)
  87. self.assertEqual(tuple(pred['pred_start'].shape), (2, 3))
  88. self.assertEqual(tuple(pred['pred_end'].shape), (2, 3))
  89. def test_bert_for_question_answering_train(self):
  90. from fastNLP import CMRC2018Loss
  91. from fastNLP.io import CMRC2018BertPipe
  92. from fastNLP import Trainer
  93. data_bundle = CMRC2018BertPipe().process_from_file('test/data_for_tests/io/cmrc')
  94. data_bundle.rename_field('chars', 'words')
  95. train_data = data_bundle.get_dataset('train')
  96. vocab = data_bundle.get_vocab('words')
  97. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  98. include_cls_sep=False, auto_truncate=True)
  99. model = BertForQuestionAnswering(embed)
  100. loss = CMRC2018Loss()
  101. trainer = Trainer(train_data, model, loss=loss, use_tqdm=False)
  102. trainer.train(load_best_model=False)
  103. def test_bert_5(self):
  104. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  105. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  106. include_cls_sep=True)
  107. model = BertForSentenceMatching(embed)
  108. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  109. pred = model(input_ids)
  110. self.assertTrue(isinstance(pred, dict))
  111. self.assertTrue(Const.OUTPUT in pred)
  112. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
  113. def test_bert_5_w(self):
  114. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  115. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  116. include_cls_sep=False)
  117. with self.assertWarns(Warning):
  118. model = BertForSentenceMatching(embed)
  119. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  120. pred = model.predict(input_ids)
  121. self.assertTrue(isinstance(pred, dict))
  122. self.assertTrue(Const.OUTPUT in pred)
  123. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2,))