import unittest

import torch
from torch import nn

from fastNLP import Vocabulary, seq_len_to_mask
from fastNLP.embeddings import StaticEmbedding
from fastNLP.modules import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State
from fastNLP.modules.generator import SequenceGenerator


def prepare_env():
    # Build a tiny vocabulary/embedding and a fake encoder output shared by the tests.
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    encoder_output = torch.randn(2, 3, 10)
    src_seq_len = torch.LongTensor([3, 2])
    encoder_mask = seq_len_to_mask(src_seq_len)

    return embed, encoder_output, encoder_mask


class GreedyDummyDecoder(Seq2SeqDecoder):
    """A stub decoder that ignores its inputs and replays pre-computed scores step by step."""

    def __init__(self, decoder_output):
        super().__init__()
        self.cur_length = 0
        self.decoder_output = decoder_output

    def decode(self, tokens, state):
        self.cur_length += 1
        scores = self.decoder_output[:, self.cur_length]
        return scores


class DummyState(State):
    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder

    def reorder_state(self, indices: torch.LongTensor):
        # Keep the stub decoder's cached scores aligned with the surviving beams.
        self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)


class TestSequenceGenerator(unittest.TestCase):
    def test_run(self):
        # Smoke test: (1) initialize the decoder, (2) run one round of generation.
        embed, encoder_output, encoder_mask = prepare_env()

        for do_sample in [True, False]:
            for num_beams in [1, 3, 5]:
                with self.subTest(do_sample=do_sample, num_beams=num_beams):
                    decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, dropout=0.3,
                                                 bind_decoder_input_output_embed=True, attention=True)
                    state = decoder.init_state(encoder_output, encoder_mask)
                    generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams,
                                                  do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0,
                                                  bos_token_id=1, eos_token_id=None, repetition_penalty=1,
                                                  length_penalty=1.0, pad_token_id=0)
                    generator.generate(state=state, tokens=None)

                    decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
                                                        d_model=encoder_output.size(-1), num_layers=2, n_head=2,
                                                        dim_ff=10, dropout=0.1,
                                                        bind_decoder_input_output_embed=True)
                    state = decoder.init_state(encoder_output, encoder_mask)
                    generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
                                                  do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0,
                                                  bos_token_id=1, eos_token_id=None, repetition_penalty=1,
                                                  length_penalty=1.0, pad_token_id=0)
                    generator.generate(state=state, tokens=None)

                    # Also exercise non-default sampling and penalty settings.
                    decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
                                                        d_model=encoder_output.size(-1), num_layers=2, n_head=2,
                                                        dim_ff=10, dropout=0.1,
                                                        bind_decoder_input_output_embed=True)
                    state = decoder.init_state(encoder_output, encoder_mask)
                    generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
                                                  do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5,
                                                  bos_token_id=1, eos_token_id=3, repetition_penalty=2,
                                                  length_penalty=1.5, pad_token_id=0)
                    generator.generate(state=state, tokens=None)

    def test_greedy_decode(self):
        # Check that greedy generate() recovers the argmax path exactly.
        for beam_search in [1, 3]:
            decoder_output = torch.randn(2, 10, 5)
            path = decoder_output.argmax(dim=-1)  # 2 x 10
            decoder = GreedyDummyDecoder(decoder_output)
            with self.subTest(beam_search=beam_search):
                generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1),
                                              num_beams=beam_search, do_sample=False, temperature=1, top_k=50,
                                              top_p=1, bos_token_id=1, eos_token_id=None, repetition_penalty=1,
                                              length_penalty=1, pad_token_id=0)
                decode_path = generator.generate(DummyState(decoder),
                                                 tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))

                self.assertEqual(decode_path.eq(path).sum(), path.numel())

        # Greedy decoding: check that eos_token_id terminates sequences.
        for beam_search in [1, 3]:
            decoder_output = torch.randn(2, 10, 5)
            decoder_output[:, :7, 4].fill_(-100)
            decoder_output[0, 7, 4] = 1000  # sequence 0 ends at step 8
            decoder_output[1, 5, 4] = 1000  # sequence 1 ends at step 6
            path = decoder_output.argmax(dim=-1)  # 2 x 10
            decoder = GreedyDummyDecoder(decoder_output)
            with self.subTest(beam_search=beam_search):
                generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1),
                                              num_beams=beam_search, do_sample=False, temperature=1, top_k=50,
                                              top_p=0.5, bos_token_id=1, eos_token_id=4, repetition_penalty=1,
                                              length_penalty=1, pad_token_id=0)
                decode_path = generator.generate(DummyState(decoder),
                                                 tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
                self.assertEqual(decode_path.size(1), 8)  # the longest sequence finishes at step 8
                self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8)
                self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6)

    def test_sample_decoder(self):
        # Sampling: check that eos_token_id terminates sequences.
        for beam_search in [1, 3]:
            with self.subTest(beam_search=beam_search):
                decode_paths = []
                # Sampling is stochastic, so repeat the trial several times; if at least one
                # run reproduces the (heavily peaked) argmax path, generation is most likely correct.
                num_tests = 10
                for i in range(num_tests):
                    decoder_output = torch.randn(2, 10, 5) * 10
                    decoder_output[:, :7, 4].fill_(-100)
                    decoder_output[0, 7, 4] = 10000  # sequence 0 ends at step 8
                    decoder_output[1, 5, 4] = 10000  # sequence 1 ends at step 6
                    path = decoder_output.argmax(dim=-1)  # 2 x 10
                    decoder = GreedyDummyDecoder(decoder_output)
                    generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1),
                                                  num_beams=beam_search, do_sample=True, temperature=1, top_k=50,
                                                  top_p=0.5, bos_token_id=1, eos_token_id=4, repetition_penalty=1,
                                                  length_penalty=1, pad_token_id=0)
                    decode_path = generator.generate(DummyState(decoder),
                                                     tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
                    decode_paths.append([decode_path, path])
                sizes = []
                eqs = []
                eq2s = []
                for i in range(num_tests):
                    decode_path, path = decode_paths[i]
                    sizes.append(decode_path.size(1) == 8)
                    eqs.append(decode_path[0].eq(path[0, :8]).sum() == 8)
                    eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum() == 6)
                self.assertTrue(any(sizes))
                self.assertTrue(any(eqs))
                self.assertTrue(any(eq2s))
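

# Standard unittest entry point so the module can also be run directly
# (e.g. `python <this_file>.py`) rather than only through a test runner.
if __name__ == '__main__':
    unittest.main()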