
test_seq2seq_encoder.py 1.2 kB

import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding


class TestTransformerSeq2SeqEncoder(unittest.TestCase):
    def test_case(self):
        # Tiny vocabulary and a randomly initialized static embedding for the test.
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=5)
        encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2)
        # A batch of one sequence with three token indices, plus its true length.
        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
        seq_len = torch.LongTensor([3])
        encoder_output, encoder_mask = encoder(words_idx, seq_len)
        # Output shape is (batch_size, seq_len, d_model).
        self.assertEqual(encoder_output.size(), (1, 3, 10))


class TestBiLSTMEncoder(unittest.TestCase):
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=5)
        encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1)
        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)
        seq_len = torch.LongTensor([3])
        encoder_output, encoder_mask = encoder(words_idx, seq_len)
        # Mask shape is (batch_size, seq_len).
        self.assertEqual(encoder_mask.size(), (1, 3))
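
The file has no __main__ guard, so executing it directly produces no output. A standard unittest entry point can be appended at the end of the file (a common convention, not part of the original source):

if __name__ == "__main__":
    unittest.main()

With that in place, `python test_seq2seq_encoder.py` runs both test cases; alternatively, `python -m unittest test_seq2seq_encoder` works without modifying the file.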