You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_other_modules.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import unittest
  2. import torch
  3. # from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
  4. from fastNLP.modules.encoder.star_transformer import StarTransformer
  5. class TestGroupNorm(unittest.TestCase):
  6. def test_case_1(self):
  7. gn = GroupNorm(num_features=1, num_groups=10, eps=1.5e-5)
  8. x = torch.randn((20, 50, 10))
  9. y = gn(x)
  10. class TestLayerNormalization(unittest.TestCase):
  11. def test_case_1(self):
  12. ln = LayerNormalization(layer_size=5, eps=2e-3)
  13. x = torch.randn((20, 50, 5))
  14. y = ln(x)
  15. class TestBiLinear(unittest.TestCase):
  16. def test_case_1(self):
  17. bl = BiLinear(n_left=5, n_right=5, n_out=10, bias=True)
  18. x_left = torch.randn((7, 10, 20, 5))
  19. x_right = torch.randn((7, 10, 20, 5))
  20. y = bl(x_left, x_right)
  21. print(bl)
  22. bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True)
  23. class TestBiAffine(unittest.TestCase):
  24. def test_case_1(self):
  25. batch_size = 16
  26. encoder_length = 21
  27. decoder_length = 32
  28. layer = BiAffine(10, 10, 25, biaffine=True)
  29. decoder_input = torch.randn((batch_size, encoder_length, 10))
  30. encoder_input = torch.randn((batch_size, decoder_length, 10))
  31. y = layer(decoder_input, encoder_input)
  32. self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length))
  33. def test_case_2(self):
  34. batch_size = 16
  35. encoder_length = 21
  36. decoder_length = 32
  37. layer = BiAffine(10, 10, 25, biaffine=False)
  38. decoder_input = torch.randn((batch_size, encoder_length, 10))
  39. encoder_input = torch.randn((batch_size, decoder_length, 10))
  40. y = layer(decoder_input, encoder_input)
  41. self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))
  42. class TestStarTransformer(unittest.TestCase):
  43. def test_1(self):
  44. model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100)
  45. x = torch.rand(16, 45, 100)
  46. mask = torch.ones(16, 45).byte()
  47. y, yn = model(x, mask)
  48. self.assertEqual(tuple(y.size()), (16, 45, 100))
  49. self.assertEqual(tuple(yn.size()), (16, 100))