You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_other_modules.py 828 B

1234567891011121314151617181920212223242526272829
  1. import unittest
  2. import torch
  3. from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear
  4. class TestGroupNorm(unittest.TestCase):
  5. def test_case_1(self):
  6. gn = GroupNorm(num_features=1, num_groups=10, eps=1.5e-5)
  7. x = torch.randn((20, 50, 10))
  8. y = gn(x)
  9. class TestLayerNormalization(unittest.TestCase):
  10. def test_case_1(self):
  11. ln = LayerNormalization(d_hid=5, eps=2e-3)
  12. x = torch.randn((20, 50, 5))
  13. y = ln(x)
  14. class TestBiLinear(unittest.TestCase):
  15. def test_case_1(self):
  16. bl = BiLinear(n_left=5, n_right=5, n_out=10, bias=True)
  17. x_left = torch.randn((7, 10, 20, 5))
  18. x_right = torch.randn((7, 10, 20, 5))
  19. y = bl(x_left, x_right)
  20. print(bl)
  21. bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True)