"""Shape smoke tests for fastNLP's Transformer and LSTM seq2seq decoders."""

import unittest

import torch

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding
from fastNLP.modules import TransformerSeq2SeqDecoder
from fastNLP.modules import LSTMSeq2SeqDecoder
from fastNLP import seq_len_to_mask


class TestTransformerSeq2SeqDecoder(unittest.TestCase):
    """Checks the output shape of ``TransformerSeq2SeqDecoder``."""

    def test_case(self):
        """Decode a (2, 4) token batch and expect a (2, 4, len(vocab)) output.

        The check runs once per value of ``bind_decoder_input_output_embed``
        so that both the bound and the unbound configuration are built and
        executed.
        """
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        # model_dir_or_name=None mirrors the LSTM test below, so both tests
        # construct their embedding the same way.
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)

        # Fake encoder states of shape (2, 3, 10) plus a mask derived from
        # the per-sample source lengths [3, 2].
        encoder_output = torch.randn(2, 3, 10)
        src_seq_len = torch.LongTensor([3, 2])
        encoder_mask = seq_len_to_mask(src_seq_len)

        for flag in [True, False]:
            with self.subTest(bind_decoder_input_output_embed=flag):
                decoder = TransformerSeq2SeqDecoder(
                    embed=embed,
                    pos_embed=None,
                    d_model=10,
                    num_layers=2,
                    n_head=5,
                    dim_ff=20,
                    dropout=0.1,
                    # Previously hard-coded to True, which left the False
                    # case of this loop untested.
                    bind_decoder_input_output_embed=flag,
                )
                state = decoder.init_state(encoder_output, encoder_mask)
                tokens = torch.randint(0, len(vocab), size=(2, 4))
                output = decoder(tokens=tokens, state=state)
                self.assertEqual(tuple(output.size()), (2, 4, len(vocab)))


class TestLSTMDecoder(unittest.TestCase):
    """Checks the output shape of ``LSTMSeq2SeqDecoder``."""

    def test_case(self):
        """Decode a fixed (2, 4) index batch and expect (2, 4, len(vocab)).

        Every combination of ``bind_decoder_input_output_embed`` and
        ``attention`` is exercised in its own subTest.
        """
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        vocab.add_word_lst("Another test !".split())
        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)

        encoder_output = torch.randn(2, 3, 10)
        tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
        src_seq_len = torch.LongTensor([3, 2])
        encoder_mask = seq_len_to_mask(src_seq_len)

        for flag in [True, False]:
            for attention in [True, False]:
                with self.subTest(bind_decoder_input_output_embed=flag,
                                  attention=attention):
                    decoder = LSTMSeq2SeqDecoder(
                        embed=embed,
                        num_layers=2,
                        hidden_size=10,
                        dropout=0.3,
                        bind_decoder_input_output_embed=flag,
                        attention=attention,
                    )
                    state = decoder.init_state(encoder_output, encoder_mask)
                    output = decoder(tgt_words_idx, state)
                    self.assertEqual(tuple(output.size()),
                                     (2, 4, len(vocab)))