You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_utils.py 298 B

123456789
  1. import unittest
  2. import torch
  3. from fastNLP.modules.utils import get_dropout_mask
  4. class TestUtil(unittest.TestCase):
  5. def test_get_dropout_mask(self):
  6. tensor = torch.randn(3, 4)
  7. mask = get_dropout_mask(0.3, tensor)
  8. self.assertSequenceEqual(mask.size(), torch.Size([3, 4]))