test_masked_rnn.py
```python
import torch
import unittest

from fastNLP.modules.encoder.masked_rnn import MaskedRNN


class TestMaskedRnn(unittest.TestCase):
    def test_case_1(self):
        # Bidirectional MaskedRNN with batch-first inputs of feature size 1.
        masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=True, batch_first=True)
        # x has shape (batch=1, seq_len=2, input_size=1).
        x = torch.tensor([[[1.0], [2.0]]])
        print(x.size())
        # Forward pass without a mask.
        y = masked_rnn(x)
        # Mask that keeps both time steps.
        mask = torch.tensor([[[1], [1]]])
        y = masked_rnn(x, mask=mask)
        # Mask that drops the second time step (e.g. padding).
        mask = torch.tensor([[[1], [0]]])
        y = masked_rnn(x, mask=mask)

    def test_case_2(self):
        # Unidirectional MaskedRNN; this case also exercises step().
        masked_rnn = MaskedRNN(input_size=1, hidden_size=1, bidirectional=False, batch_first=True)
        x = torch.tensor([[[1.0], [2.0]]])
        print(x.size())
        y = masked_rnn(x)
        mask = torch.tensor([[[1], [1]]])
        y = masked_rnn(x, mask=mask)
        # Single-step input for step(), called without and with a mask.
        xx = torch.tensor([[[1.0]]])
        y = masked_rnn.step(xx)
        y = masked_rnn.step(xx, mask=mask)


if __name__ == "__main__":
    unittest.main()
```
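The tests above only check that `MaskedRNN` runs with and without a mask. They do not check what the mask does. The following is a minimal pure-Python sketch of one common way masking is applied in a recurrent network. At a position whose mask is 1, the new state from the cell is used. At a position whose mask is 0 (padding), the previous state is carried over unchanged. The scalar cell, its weights, and the function names `rnn_cell`, `masked_step`, `masked_rnn`, and `masked_rnn_bidirectional` are hypothetical and are not fastNLP's API. fastNLP's `MaskedRNN` may differ in its details, for example in what it outputs at masked positions.

```python
import math
from typing import List, Tuple


def rnn_cell(x: float, h: float, w_ih: float = 0.5, w_hh: float = 0.5, b: float = 0.0) -> float:
    """Hypothetical scalar Elman cell: tanh(w_ih * x + w_hh * h + b); weights are illustrative."""
    return math.tanh(w_ih * x + w_hh * h + b)


def masked_step(x: float, h_prev: float, m: int) -> float:
    """One masked step: take the new state where m == 1, keep h_prev where m == 0."""
    h_new = rnn_cell(x, h_prev)
    return m * h_new + (1 - m) * h_prev


def masked_rnn(xs: List[float], mask: List[int], h0: float = 0.0) -> List[float]:
    """Unidirectional masked RNN over one sequence; returns the state after every step."""
    h = h0
    states = []
    for x, m in zip(xs, mask):
        h = masked_step(x, h, m)
        states.append(h)
    return states


def masked_rnn_bidirectional(xs: List[float], mask: List[int], h0: float = 0.0) -> List[Tuple[float, float]]:
    """Run the masked RNN left-to-right and right-to-left, pairing the states per position."""
    forward = masked_rnn(xs, mask, h0)
    # Reverse the input and mask, run, then reverse the states back into original order.
    backward = masked_rnn(xs[::-1], mask[::-1], h0)[::-1]
    return list(zip(forward, backward))


if __name__ == "__main__":
    xs = [1.0, 2.0]  # same values as x in the tests
    print(masked_rnn(xs, [1, 1]))               # both steps update the state
    print(masked_rnn(xs, [1, 0]))               # second state repeats the first
    print(masked_rnn_bidirectional(xs, [1, 0]))
```

This design matters most in the bidirectional case. With a mask of `[1, 0]`, the backward pass skips the padded position, so its state at the first token equals the state you would get from the unpadded one-token sequence.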

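A lightweight natural language processing (NLP) toolkit that aims to reduce the engineering code in users' projects, such as data-processing loops, training loops, and multi-GPU execution.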