You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_other_modules.py 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import unittest
  2. import torch
  3. from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
  4. class TestGroupNorm(unittest.TestCase):
  5. def test_case_1(self):
  6. gn = GroupNorm(num_features=1, num_groups=10, eps=1.5e-5)
  7. x = torch.randn((20, 50, 10))
  8. y = gn(x)
  9. class TestLayerNormalization(unittest.TestCase):
  10. def test_case_1(self):
  11. ln = LayerNormalization(layer_size=5, eps=2e-3)
  12. x = torch.randn((20, 50, 5))
  13. y = ln(x)
  14. class TestBiLinear(unittest.TestCase):
  15. def test_case_1(self):
  16. bl = BiLinear(n_left=5, n_right=5, n_out=10, bias=True)
  17. x_left = torch.randn((7, 10, 20, 5))
  18. x_right = torch.randn((7, 10, 20, 5))
  19. y = bl(x_left, x_right)
  20. print(bl)
  21. bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True)
  22. class TestBiAffine(unittest.TestCase):
  23. def test_case_1(self):
  24. batch_size = 16
  25. encoder_length = 21
  26. decoder_length = 32
  27. layer = BiAffine(10, 10, 25, biaffine=True)
  28. decoder_input = torch.randn((batch_size, encoder_length, 10))
  29. encoder_input = torch.randn((batch_size, decoder_length, 10))
  30. y = layer(decoder_input, encoder_input)
  31. self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length))
  32. def test_case_2(self):
  33. batch_size = 16
  34. encoder_length = 21
  35. decoder_length = 32
  36. layer = BiAffine(10, 10, 25, biaffine=False)
  37. decoder_input = torch.randn((batch_size, encoder_length, 10))
  38. encoder_input = torch.randn((batch_size, decoder_length, 10))
  39. y = layer(decoder_input, encoder_input)
  40. self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))