You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_seq2seq_generator.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import unittest
  2. import torch
  3. from fastNLP.modules.generator import SequenceGenerator
  4. from fastNLP.modules import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State
  5. from fastNLP import Vocabulary
  6. from fastNLP.embeddings import StaticEmbedding
  7. from torch import nn
  8. from fastNLP import seq_len_to_mask
  9. def prepare_env():
  10. vocab = Vocabulary().add_word_lst("This is a test .".split())
  11. vocab.add_word_lst("Another test !".split())
  12. embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)
  13. encoder_output = torch.randn(2, 3, 10)
  14. src_seq_len = torch.LongTensor([3, 2])
  15. encoder_mask = seq_len_to_mask(src_seq_len)
  16. return embed, encoder_output, encoder_mask
class TestSequenceGenerator(unittest.TestCase):
    """Tests for SequenceGenerator: a smoke test over real decoders, and an
    exact-path check using a dummy decoder that replays fixed scores."""

    def test_run(self):
        # Smoke test: (1) build each decoder, (2) run one round of generation,
        # across greedy/sampling modes and several beam widths.
        embed, encoder_output, encoder_mask = prepare_env()
        for do_sample in [True, False]:
            for num_beams in [1, 3, 5]:
                with self.subTest(do_sample=do_sample, num_beams=num_beams):
                    # LSTM decoder with attention.
                    decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10,
                                                 dropout=0.3, bind_decoder_input_output_embed=True, attention=True)
                    state = decoder.init_state(encoder_output, encoder_mask)
                    generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams,
                                                  do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None,
                                                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0)
                    generator.generate(state=state, tokens=None)

                    # Transformer decoder with the same generator settings.
                    decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
                                                        d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1,
                                                        bind_decoder_input_output_embed=True)
                    state = decoder.init_state(encoder_output, encoder_mask)
                    generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
                                                  do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None,
                                                  repetition_penalty=1, length_penalty=1.0, pad_token_id=0)
                    generator.generate(state=state, tokens=None)

                    # Try other hyper-parameter values (eos, penalties, top_p/temperature).
                    decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim),
                                                        d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10,
                                                        dropout=0.1,
                                                        bind_decoder_input_output_embed=True)
                    state = decoder.init_state(encoder_output, encoder_mask)
                    generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams,
                                                  do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1,
                                                  eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0)
                    generator.generate(state=state, tokens=None)

    def test_greedy_decode(self):
        # Check that generation follows the per-step argmax of a fixed score
        # tensor served by a dummy decoder (so the expected path is known).

        class GreedyDummyDecoder(Seq2SeqDecoder):
            """Replays pre-computed scores: decode step t serves decoder_output[:, t]."""

            def __init__(self, decoder_output):
                super().__init__()
                self.cur_length = 0          # number of decode steps taken so far
                self.decoder_output = decoder_output  # (batch, steps, vocab) fixed scores

            def decode(self, tokens, state):
                # Advance one step, then serve that step's scores; step 0 is
                # consumed by the initial `tokens` passed to generate().
                self.cur_length += 1
                scores = self.decoder_output[:, self.cur_length]
                return scores

        class DummyState(State):
            """Minimal State that keeps the dummy decoder's scores beam-aligned."""

            def __init__(self, decoder):
                super().__init__()
                self.decoder = decoder

            def reorder_state(self, indices: torch.LongTensor):
                # Reorder the replayed scores along the batch dim so they stay
                # consistent when beam search permutes hypotheses.
                self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)

        # Greedy (no eos): the generated path must equal the argmax path exactly.
        for beam_search in [1, 3]:
            decoder_output = torch.randn(2, 10, 5)
            path = decoder_output.argmax(dim=-1)  # 2 x 10 expected token path
            decoder = GreedyDummyDecoder(decoder_output)
            with self.subTest(beam_search=beam_search):
                generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
                                              do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1,
                                              eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0)
                decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
                self.assertEqual(decode_path.eq(path).sum(), path.numel())

        # Greedy with eos_token_id=4: force eos to win at chosen steps and
        # check that generation stops (and pads) accordingly.
        for beam_search in [1, 3]:
            decoder_output = torch.randn(2, 10, 5)
            decoder_output[:, :7, 4].fill_(-100)  # eos cannot win before step 7
            decoder_output[0, 7, 4] = 1000  # sample 0 ends at the 8th step
            decoder_output[1, 5, 4] = 1000  # sample 1 ends earlier, at the 6th step
            path = decoder_output.argmax(dim=-1)  # 2 x 10 expected token path
            decoder = GreedyDummyDecoder(decoder_output)
            with self.subTest(beam_search=beam_search):
                generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search,
                                              do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1,
                                              eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0)
                decode_path = generator.generate(DummyState(decoder),
                                                 tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True))
                self.assertEqual(decode_path.size(1), 8)  # length 8: longest finished sequence
                self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8)
                self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6)