You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_pooling.py 1.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import unittest
  2. import torch
  3. from fastNLP.modules.encoder.pooling import MaxPool, MaxPoolWithMask, KMaxPool, AvgPool, AvgPoolWithMask
  4. class TestPooling(unittest.TestCase):
  5. def test_MaxPool(self):
  6. max_pool_1d = MaxPool(dimension=1)
  7. x = torch.randn(5, 6, 7)
  8. self.assertEqual(max_pool_1d(x).size(), (5, 7))
  9. max_pool_2d = MaxPool(dimension=2)
  10. self.assertEqual(max_pool_2d(x).size(), (5, 1))
  11. max_pool_3d = MaxPool(dimension=3)
  12. x = torch.randn(4, 5, 6, 7)
  13. self.assertEqual(max_pool_3d(x).size(), (4, 1, 1))
  14. def test_MaxPoolWithMask(self):
  15. pool = MaxPoolWithMask()
  16. x = torch.randn(5, 6, 7)
  17. mask = (torch.randn(5, 6) > 0).long()
  18. self.assertEqual(pool(x, mask).size(), (5, 7))
  19. def test_KMaxPool(self):
  20. k_pool = KMaxPool(k=3)
  21. x = torch.randn(4, 5, 6)
  22. self.assertEqual(k_pool(x).size(), (4, 15))
  23. def test_AvgPool(self):
  24. pool = AvgPool()
  25. x = torch.randn(4, 5, 6)
  26. self.assertEqual(pool(x).size(), (4, 5))
  27. def test_AvgPoolWithMask(self):
  28. pool = AvgPoolWithMask()
  29. x = torch.randn(5, 6, 7)
  30. mask = (torch.randn(5, 6) > 0).long()
  31. self.assertEqual(pool(x, mask).size(), (5, 7))