You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_bert.py 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import unittest
  2. import torch
  3. from fastNLP.models.bert import *
  class TestBert(unittest.TestCase):
      """Forward-pass smoke tests for the BERT task models in ``fastNLP.models.bert``.

      Each test builds one model, feeds it the same small batch and checks
      that the prediction is a dict whose output entries have the expected
      shape. No weights are loaded and no numerical values are compared.
      """

      def setUp(self):
          """Create the 2 x 3 batch of token ids, mask and segment ids shared by all tests."""
          self.input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
          # The mask is 0 only at the last position of the second row
          # (presumably padding -- confirm against the models' mask convention).
          self.input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
          self.token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

      def _forward(self, model):
          """Run ``model`` on the shared batch and return whatever it returns."""
          # Arguments are positional, in the order: ids, token type ids, mask.
          return model(self.input_ids, self.token_type_ids, self.input_mask)

      def _assert_output(self, pred, key, expected_shape):
          """Check that ``pred`` is a dict holding ``key`` with the given shape.

          ``assertIsInstance`` / ``assertIn`` are used instead of
          ``assertTrue(...)`` so that a failure reports the offending value
          rather than just "False is not true".
          """
          self.assertIsInstance(pred, dict)
          self.assertIn(key, pred)
          self.assertEqual(tuple(pred[key].shape), expected_shape)

      def test_bert_1(self):
          """``BertForSequenceClassification(2)`` must produce a (2, 2) output."""
          from fastNLP.core.const import Const

          model = BertForSequenceClassification(2)
          pred = self._forward(model)
          self._assert_output(pred, Const.OUTPUT, (2, 2))

      def test_bert_2(self):
          """``BertForMultipleChoice(2)`` must produce a (1, 2) output."""
          from fastNLP.core.const import Const

          model = BertForMultipleChoice(2)
          pred = self._forward(model)
          self._assert_output(pred, Const.OUTPUT, (1, 2))

      def test_bert_3(self):
          """``BertForTokenClassification(7)`` must produce a (2, 3, 7) output."""
          from fastNLP.core.const import Const

          model = BertForTokenClassification(7)
          pred = self._forward(model)
          self._assert_output(pred, Const.OUTPUT, (2, 3, 7))

      def test_bert_4(self):
          """``BertForQuestionAnswering()`` must produce two (2, 3) outputs."""
          from fastNLP.core.const import Const

          model = BertForQuestionAnswering()
          pred = self._forward(model)
          # The prediction carries two entries, keyed by Const.OUTPUTS(0) and
          # Const.OUTPUTS(1); both are expected to have shape (2, 3).
          for index in (0, 1):
              self._assert_output(pred, Const.OUTPUTS(index), (2, 3))