You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils.py 706 B

1234567891011121314151617181920
  1. import unittest
  2. import torch
  3. from fastNLP.models import CNNText
  4. from fastNLP.modules.utils import get_dropout_mask, summary
  class TestUtil(unittest.TestCase):
      """Unit tests for the helpers imported from ``fastNLP.modules.utils``."""

      def test_get_dropout_mask(self):
          """A dropout mask built for a (3, 4) tensor must itself be (3, 4)."""
          rows, cols = 3, 4
          source = torch.randn(rows, cols)
          drop_mask = get_dropout_mask(0.3, source)
          self.assertSequenceEqual(drop_mask.size(), torch.Size([rows, cols]))

      def test_summary(self):
          """Check the parameter counts reported by ``summary`` for a tiny CNNText.

          The expected tuples read as (total, trainable, non-trainable): after
          the embedding is frozen, its 16 parameters move from the second slot
          to the third while the total stays at 142.
          """
          net = CNNText(embed=(4, 4), num_classes=2,
                        kernel_nums=(9, 5), kernel_sizes=(1, 3))

          # Per-component parameter counts; together they sum to 142.
          n_embed = 4 * 4  # 4 x 4 embedding table
          # Conv kernels over the 4-dim embedding: nine width-1 filters and
          # five width-3 filters. No bias term is counted here -- presumably
          # the conv layers are built without bias; confirm against fastNLP.
          n_conv = 4 * (9 * 1 + 5 * 3)
          # Output layer: presumably the (9 + 5) pooled conv features plus a
          # bias, for each of the 2 classes.
          n_fc = 2 * (9 + 5 + 1)
          n_total = n_embed + n_conv + n_fc

          self.assertSequenceEqual((n_total, n_total, 0), summary(net))

          # Freezing the embedding should shift exactly its parameters from
          # the trainable count to the non-trainable count.
          net.embed.requires_grad = False
          self.assertSequenceEqual((n_total, n_total - n_embed, n_embed), summary(net))