You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_other_modules.py 490 B

123456789101112131415
  1. import unittest
  2. import torch
  3. from fastNLP.modules.encoder.star_transformer import StarTransformer
  4. class TestStarTransformer(unittest.TestCase):
  5. def test_1(self):
  6. model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100)
  7. x = torch.rand(16, 45, 100)
  8. mask = torch.ones(16, 45).byte()
  9. y, yn = model(x, mask)
  10. self.assertEqual(tuple(y.size()), (16, 45, 100))
  11. self.assertEqual(tuple(yn.size()), (16, 100))