You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_bert.py 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import unittest
  2. import torch
  3. from fastNLP.core import Vocabulary, Const
  4. from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \
  5. BertForTokenClassification, BertForMultipleChoice, BertForSentenceMatching
  6. from fastNLP.embeddings.bert_embedding import BertEmbedding
  7. class TestBert(unittest.TestCase):
  8. def test_bert_1(self):
  9. vocab = Vocabulary().add_word_lst("this is a test .".split())
  10. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  11. include_cls_sep=True)
  12. model = BertForSequenceClassification(embed, 2)
  13. input_ids = torch.LongTensor([[1, 2, 3], [5, 6, 0]])
  14. pred = model(input_ids)
  15. self.assertTrue(isinstance(pred, dict))
  16. self.assertTrue(Const.OUTPUT in pred)
  17. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
  18. pred = model(input_ids)
  19. self.assertTrue(isinstance(pred, dict))
  20. self.assertTrue(Const.OUTPUT in pred)
  21. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
  22. def test_bert_1_w(self):
  23. vocab = Vocabulary().add_word_lst("this is a test .".split())
  24. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  25. include_cls_sep=False)
  26. with self.assertWarns(Warning):
  27. model = BertForSequenceClassification(embed, 2)
  28. input_ids = torch.LongTensor([[1, 2, 3], [5, 6, 0]])
  29. pred = model.predict(input_ids)
  30. self.assertTrue(isinstance(pred, dict))
  31. self.assertTrue(Const.OUTPUT in pred)
  32. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2,))
  33. def test_bert_2(self):
  34. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  35. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  36. include_cls_sep=True)
  37. model = BertForMultipleChoice(embed, 2)
  38. input_ids = torch.LongTensor([[[2, 6, 7], [1, 6, 5]]])
  39. print(input_ids.size())
  40. pred = model(input_ids)
  41. self.assertTrue(isinstance(pred, dict))
  42. self.assertTrue(Const.OUTPUT in pred)
  43. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
  44. def test_bert_2_w(self):
  45. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  46. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  47. include_cls_sep=False)
  48. with self.assertWarns(Warning):
  49. model = BertForMultipleChoice(embed, 2)
  50. input_ids = torch.LongTensor([[[2, 6, 7], [1, 6, 5]]])
  51. print(input_ids.size())
  52. pred = model.predict(input_ids)
  53. self.assertTrue(isinstance(pred, dict))
  54. self.assertTrue(Const.OUTPUT in pred)
  55. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1,))
  56. def test_bert_3(self):
  57. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  58. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  59. include_cls_sep=False)
  60. model = BertForTokenClassification(embed, 7)
  61. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  62. pred = model(input_ids)
  63. self.assertTrue(isinstance(pred, dict))
  64. self.assertTrue(Const.OUTPUT in pred)
  65. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
  66. def test_bert_3_w(self):
  67. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  68. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  69. include_cls_sep=True)
  70. with self.assertWarns(Warning):
  71. model = BertForTokenClassification(embed, 7)
  72. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  73. pred = model.predict(input_ids)
  74. self.assertTrue(isinstance(pred, dict))
  75. self.assertTrue(Const.OUTPUT in pred)
  76. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3))
  77. def test_bert_4(self):
  78. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  79. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  80. include_cls_sep=True)
  81. model = BertForQuestionAnswering(embed)
  82. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  83. pred = model(input_ids)
  84. self.assertTrue(isinstance(pred, dict))
  85. self.assertTrue(Const.OUTPUTS(0) in pred)
  86. self.assertTrue(Const.OUTPUTS(1) in pred)
  87. self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 5))
  88. self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 5))
  89. model = BertForQuestionAnswering(embed, 7)
  90. pred = model(input_ids)
  91. self.assertTrue(isinstance(pred, dict))
  92. self.assertEqual(len(pred), 7)
  93. def test_bert_4_w(self):
  94. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  95. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  96. include_cls_sep=False)
  97. with self.assertWarns(Warning):
  98. model = BertForQuestionAnswering(embed)
  99. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  100. pred = model.predict(input_ids)
  101. self.assertTrue(isinstance(pred, dict))
  102. self.assertTrue(Const.OUTPUTS(1) in pred)
  103. self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2,))
  104. def test_bert_5(self):
  105. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  106. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  107. include_cls_sep=True)
  108. model = BertForSentenceMatching(embed)
  109. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  110. pred = model(input_ids)
  111. self.assertTrue(isinstance(pred, dict))
  112. self.assertTrue(Const.OUTPUT in pred)
  113. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
  114. def test_bert_5_w(self):
  115. vocab = Vocabulary().add_word_lst("this is a test [SEP] .".split())
  116. embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert',
  117. include_cls_sep=False)
  118. with self.assertWarns(Warning):
  119. model = BertForSentenceMatching(embed)
  120. input_ids = torch.LongTensor([[1, 2, 3], [6, 5, 0]])
  121. pred = model.predict(input_ids)
  122. self.assertTrue(isinstance(pred, dict))
  123. self.assertTrue(Const.OUTPUT in pred)
  124. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2,))