You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_seq2seq_generator.py 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import unittest
  2. import torch
  3. from fastNLP.modules.generator import SequenceGenerator
  4. from fastNLP.modules import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State
  5. from fastNLP import Vocabulary
  6. from fastNLP.embeddings import StaticEmbedding
  7. from torch import nn
  8. from fastNLP import seq_len_to_mask
  9. def prepare_env():
  10. vocab = Vocabulary().add_word_lst("This is a test .".split())
  11. vocab.add_word_lst("Another test !".split())
  12. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)
  13. encoder_output = torch.randn(2, 3, 10)
  14. src_seq_len = torch.LongTensor([3, 2])
  15. encoder_mask = seq_len_to_mask(src_seq_len)
  16. return embed, encoder_output, encoder_mask
  17. class GreedyDummyDecoder(Seq2SeqDecoder):
  18. def __init__(self, decoder_output):
  19. super().__init__()
  20. self.cur_length = 0
  21. self.decoder_output = decoder_output
  22. def decode(self, tokens, state):
  23. self.cur_length += 1
  24. scores = self.decoder_output[:, self.cur_length]
  25. return scores
  26. class DummyState(State):
  27. def __init__(self, decoder):
  28. super().__init__()
  29. self.decoder = decoder
  30. def reorder_state(self, indices: torch.LongTensor):
  31. self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)
  32. class TestSequenceGenerator(unittest.TestCase):
  33. def test_run(self):
  34. # 测试能否运行 (1) 初始化decoder,(2) decode一发
  35. embed, encoder_output, encoder_mask = prepare_env()
  36. for do_sample in [True, False]:
  37. for num_beams in [1, 3, 5]:
  38. with self.subTest(do_sample=do_sample, num_beams=num_beams):
  39. decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10,
  40. dropout=0.3, bind_decoder_input_output_embed=True, attention=True)
  41. state = decoder.init_state(encoder_output, encoder_mask)
  42. generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams,
  43. do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None,
  44. repetition_penalty=1, length_penalty=1.0, pad_token_id=0)
  45. generator.generate(state=state, tokens=None)
  46. decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
  47. d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1,
  48. bind_decoder_input_output_embed=True)
  49. state = decoder.init_state(encoder_output, encoder_mask)
  50. generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
  51. do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None,
  52. repetition_penalty=1, length_penalty=1.0, pad_token_id=0)
  53. generator.generate(state=state, tokens=None)
  54. # 测试一下其它值
  55. decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
  56. d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10,
  57. dropout=0.1,
  58. bind_decoder_input_output_embed=True)
  59. state = decoder.init_state(encoder_output, encoder_mask)
  60. generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
  61. do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1,
  62. eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0)
  63. generator.generate(state=state, tokens=None)
  64. def test_greedy_decode(self):
  65. # 测试能否正确的generate
  66. # greedy
  67. for beam_search in [1, 3]:
  68. decoder_output = torch.randn(2, 10, 5)
  69. path = decoder_output.argmax(dim=-1) # 2 x 10
  70. decoder = GreedyDummyDecoder(decoder_output)
  71. with self.subTest(msg=beam_search, beam_search=beam_search):
  72. generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
  73. do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1,
  74. eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0)
  75. decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
  76. self.assertEqual(decode_path.eq(path).sum(), path.numel())
  77. # greedy check eos_token_id
  78. for beam_search in [1, 3]:
  79. decoder_output = torch.randn(2, 10, 5)
  80. decoder_output[:, :7, 4].fill_(-100)
  81. decoder_output[0, 7, 4] = 1000 # 在第8个结束
  82. decoder_output[1, 5, 4] = 1000
  83. path = decoder_output.argmax(dim=-1) # 2 x 4
  84. decoder = GreedyDummyDecoder(decoder_output)
  85. with self.subTest(beam_search=beam_search):
  86. generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
  87. do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1,
  88. eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0)
  89. decode_path = generator.generate(DummyState(decoder),
  90. tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
  91. self.assertEqual(decode_path.size(1), 8) # 长度为8
  92. self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8)
  93. self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6)
  94. def test_sample_decoder(self):
  95. # greedy check eos_token_id
  96. for beam_search in [1, 3]:
  97. with self.subTest(beam_search=beam_search):
  98. decode_paths = []
  99. # 因为是随机,所以需要测试100次,如果至少有一次是对的,应该就问题不大
  100. num_tests = 10
  101. for i in range(num_tests):
  102. decoder_output = torch.randn(2, 10, 5) * 10
  103. decoder_output[:, :7, 4].fill_(-100)
  104. decoder_output[0, 7, 4] = 10000 # 在第8个结束
  105. decoder_output[1, 5, 4] = 10000
  106. path = decoder_output.argmax(dim=-1) # 2 x 4
  107. decoder = GreedyDummyDecoder(decoder_output)
  108. generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
  109. do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1,
  110. eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0)
  111. decode_path = generator.generate(DummyState(decoder),
  112. tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
  113. decode_paths.append([decode_path, path])
  114. sizes = []
  115. eqs = []
  116. eq2s = []
  117. for i in range(num_tests):
  118. decode_path, path = decode_paths[i]
  119. sizes.append(decode_path.size(1)==8)
  120. eqs.append(decode_path[0].eq(path[0, :8]).sum()==8)
  121. eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum()==6)
  122. self.assertTrue(any(sizes))
  123. self.assertTrue(any(eqs))
  124. self.assertTrue(any(eq2s))