
test_seq2seq_decoder.py

import unittest

import torch

from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
from fastNLP.modules.decoder.seq2seq_decoder import TransformerSeq2SeqDecoder, TransformerPast, LSTMPast, \
    LSTMSeq2SeqDecoder
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.core.utils import seq_len_to_mask


class TestTransformerSeq2SeqDecoder(unittest.TestCase):
    def test_case(self):
        # Small toy vocabulary with 512-dim static embeddings.
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, embedding_dim=512)

        # Build the full Transformer seq2seq model with its default args,
        # sharing the same vocabulary on the source and target side.
        args = TransformerSeq2SeqModel.add_args()
        model = TransformerSeq2SeqModel.build_model(args, vocab, vocab)

        # Batch of 2; source lengths are 3 and 2 (index 0 is padding).
        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])

        output = model(src_words_idx, src_seq_len, tgt_words_idx)
        print(output)
        # self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))

    def test_decode(self):
        pass  # todo


class TestLSTMDecoder(unittest.TestCase):
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, embedding_dim=512)

        # The original body referenced BiLSTMEncoder/LSTMDecoder, which are not
        # imported above; the imported LSTMSeq2SeqEncoder/LSTMSeq2SeqDecoder
        # classes are used instead.
        encoder = LSTMSeq2SeqEncoder(embed)
        decoder = LSTMSeq2SeqDecoder(embed, bind_input_output_embed=True)

        src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])

        words, hx = encoder(src_words_idx, src_seq_len)
        encode_mask = seq_len_to_mask(src_seq_len)

        # The encoder is bidirectional: concatenate the final forward and
        # backward hidden/cell states, then repeat them for every decoder layer.
        hidden = torch.cat([hx[0][-2:-1], hx[0][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
        cell = torch.cat([hx[1][-2:-1], hx[1][-1:]], dim=-1).repeat(decoder.num_layers, 1, 1)
        past = LSTMPast(encode_outputs=words, encode_mask=encode_mask, hx=(hidden, cell))

        decoder_outputs = decoder(tgt_words_idx, past)
        print(decoder_outputs)
        print(encode_mask)
        # Expect per-token distributions over the vocabulary: (batch, tgt_len, vocab).
        self.assertEqual(tuple(decoder_outputs.size()), (2, 4, len(vocab)))

    def test_decode(self):
        pass  # todo
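

# A minimal sketch of an entry point so the file can be run directly
# (e.g. `python test_seq2seq_decoder.py`). This is an addition, not part of
# the original file, and assumes only the standard-library unittest runner.
if __name__ == '__main__':
    unittest.main()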