import unittest

import torch

from fastNLP.modules.encoder.pooling import (
    AvgPool,
    AvgPoolWithMask,
    KMaxPool,
    MaxPool,
    MaxPoolWithMask,
)


class TestPooling(unittest.TestCase):
    """Output-shape checks for the pooling modules in ``fastNLP.modules.encoder.pooling``.

    Each test pushes random tensors through one pooling module and compares
    only the shape of the result; the pooled values are not inspected.
    """

    def _assert_shape(self, tensor, expected):
        """Assert that ``tensor.shape`` equals the tuple ``expected``."""
        self.assertEqual(tensor.shape, expected)

    @staticmethod
    def _random_mask(rows, cols):
        """Return a random 0/1 mask of shape ``(rows, cols)`` with dtype ``long``."""
        return (torch.randn(rows, cols) > 0).long()

    def test_MaxPool(self):
        # The same 3-D input is fed to the dimension=1 and dimension=2 variants.
        three_d = torch.randn(5, 6, 7)
        self._assert_shape(MaxPool(dimension=1)(three_d), (5, 7))
        self._assert_shape(MaxPool(dimension=2)(three_d), (5, 1))

        # The dimension=3 variant is exercised on a 4-D input.
        four_d = torch.randn(4, 5, 6, 7)
        self._assert_shape(MaxPool(dimension=3)(four_d), (4, 1, 1))

    def test_MaxPoolWithMask(self):
        inputs = torch.randn(5, 6, 7)
        mask = self._random_mask(5, 6)
        self._assert_shape(MaxPoolWithMask()(inputs, mask), (5, 7))

    def test_KMaxPool(self):
        inputs = torch.randn(4, 5, 6)
        # k=3 on a (4, 5, 6) input is expected to give 5 * 3 = 15 columns.
        self._assert_shape(KMaxPool(k=3)(inputs), (4, 15))

    def test_AvgPool(self):
        inputs = torch.randn(4, 5, 6)
        self._assert_shape(AvgPool()(inputs), (4, 5))

    def test_AvgPoolWithMask(self):
        inputs = torch.randn(5, 6, 7)
        mask = self._random_mask(5, 6)
        self._assert_shape(AvgPoolWithMask()(inputs, mask), (5, 7))