You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_bert.py 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import unittest
  2. import torch
  3. from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \
  4. BertForTokenClassification, BertForMultipleChoice
  5. class TestBert(unittest.TestCase):
  6. def test_bert_1(self):
  7. from fastNLP.core.const import Const
  8. from fastNLP.modules.encoder.bert import BertConfig
  9. model = BertForSequenceClassification(2, BertConfig(32000))
  10. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  11. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  12. pred = model(input_ids, input_mask)
  13. self.assertTrue(isinstance(pred, dict))
  14. self.assertTrue(Const.OUTPUT in pred)
  15. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
  16. input_mask = torch.LongTensor([3, 2])
  17. pred = model(input_ids, input_mask)
  18. self.assertTrue(isinstance(pred, dict))
  19. self.assertTrue(Const.OUTPUT in pred)
  20. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
  21. def test_bert_2(self):
  22. from fastNLP.core.const import Const
  23. from fastNLP.modules.encoder.bert import BertConfig
  24. model = BertForMultipleChoice(2, BertConfig(32000))
  25. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  26. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  27. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  28. pred = model(input_ids, token_type_ids, input_mask)
  29. self.assertTrue(isinstance(pred, dict))
  30. self.assertTrue(Const.OUTPUT in pred)
  31. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2))
  32. def test_bert_3(self):
  33. from fastNLP.core.const import Const
  34. from fastNLP.modules.encoder.bert import BertConfig
  35. model = BertForTokenClassification(7, BertConfig(32000))
  36. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  37. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  38. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  39. pred = model(input_ids, token_type_ids, input_mask)
  40. self.assertTrue(isinstance(pred, dict))
  41. self.assertTrue(Const.OUTPUT in pred)
  42. self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7))
  43. def test_bert_4(self):
  44. from fastNLP.core.const import Const
  45. from fastNLP.modules.encoder.bert import BertConfig
  46. model = BertForQuestionAnswering(BertConfig(32000))
  47. input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
  48. input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
  49. token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
  50. pred = model(input_ids, token_type_ids, input_mask)
  51. self.assertTrue(isinstance(pred, dict))
  52. self.assertTrue(Const.OUTPUTS(0) in pred)
  53. self.assertTrue(Const.OUTPUTS(1) in pred)
  54. self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3))
  55. self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3))