You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_variational_rnn.py 714 B

7 years ago
7 years ago
7 years ago
7 years ago
12345678910111213141516171819202122232425
  1. import unittest
  2. import numpy as np
  3. import torch
  4. from fastNLP.modules.encoder.variational_rnn import VarLSTM
  5. class TestMaskedRnn(unittest.TestCase):
  6. def test_case_1(self):
  7. masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
  8. x = torch.tensor([[[1.0], [2.0]]])
  9. print(x.size())
  10. y = masked_rnn(x)
  11. def test_case_2(self):
  12. input_size = 12
  13. batch = 16
  14. hidden = 10
  15. masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True)
  16. xx = torch.randn((batch, 32, input_size))
  17. y, _ = masked_rnn(xx)
  18. self.assertEqual(tuple(y.shape), (batch, 32, hidden))