
test_seq2seq_model.py

import unittest

import torch
from torch import optim
import torch.nn.functional as F

from fastNLP import Vocabulary, seq_len_to_mask
from fastNLP.embeddings import StaticEmbedding
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel

def prepare_env():
    # Build a tiny vocabulary, a small random embedding, and toy src/tgt batches.
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]])
    tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
    src_seq_len = torch.LongTensor([3, 2])
    tgt_seq_len = torch.LongTensor([4, 2])

    return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len

def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len):
    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    # Padded target positions get label -100 so cross_entropy ignores them.
    mask = seq_len_to_mask(tgt_seq_len).eq(0)
    target = tgt_words_idx.masked_fill(mask, -100)

    for _ in range(100):
        optimizer.zero_grad()
        pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred']  # bsz x max_len x vocab_size
        loss = F.cross_entropy(pred.transpose(1, 2), target)
        loss.backward()
        optimizer.step()

    # Count correct predictions; padded positions are counted as correct.
    right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum()
    return right_count

class TestTransformerSeq2SeqModel(unittest.TestCase):
    def test_run(self):
        # Check that the forward pass runs
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        for pos_embed in ['learned', 'sin']:
            with self.subTest(pos_embed=pos_embed):
                model = TransformerSeq2SeqModel.build_model(
                    src_embed=embed, tgt_embed=None,
                    pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30,
                    n_head=6, dim_ff=20, dropout=0.1,
                    bind_encoder_decoder_embed=True,
                    bind_decoder_input_output_embed=True)

                output = model(src_words_idx, tgt_words_idx, src_seq_len)
                self.assertEqual(output['pred'].size(), (2, 4, len(embed)))

        for bind_encoder_decoder_embed in [True, False]:
            tgt_embed = None
            for bind_decoder_input_output_embed in [True, False]:
                if not bind_encoder_decoder_embed:
                    tgt_embed = embed
                with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                                  bind_decoder_input_output_embed=bind_decoder_input_output_embed):
                    model = TransformerSeq2SeqModel.build_model(
                        src_embed=embed, tgt_embed=tgt_embed,
                        pos_embed='sin', max_position=20, num_layers=2,
                        d_model=30, n_head=6, dim_ff=20, dropout=0.1,
                        bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                        bind_decoder_input_output_embed=bind_decoder_input_output_embed)
                    output = model(src_words_idx, tgt_words_idx, src_seq_len)
                    self.assertEqual(output['pred'].size(), (2, 4, len(embed)))
    def test_train(self):
        # Check that the model can overfit the toy data
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
        model = TransformerSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            pos_embed='sin', max_position=20, num_layers=2, d_model=30,
            n_head=6, dim_ff=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True)

        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
        self.assertEqual(right_count, tgt_words_idx.nelement())

class TestLSTMSeq2SeqModel(unittest.TestCase):
    def test_run(self):
        # Check that the forward pass runs
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()

        for bind_encoder_decoder_embed in [True, False]:
            tgt_embed = None
            for bind_decoder_input_output_embed in [True, False]:
                if not bind_encoder_decoder_embed:
                    tgt_embed = embed
                with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                                  bind_decoder_input_output_embed=bind_decoder_input_output_embed):
                    model = LSTMSeq2SeqModel.build_model(
                        src_embed=embed, tgt_embed=tgt_embed,
                        num_layers=2, hidden_size=20, dropout=0.1,
                        bind_encoder_decoder_embed=bind_encoder_decoder_embed,
                        bind_decoder_input_output_embed=bind_decoder_input_output_embed)
                    output = model(src_words_idx, tgt_words_idx, src_seq_len)
                    self.assertEqual(output['pred'].size(), (2, 4, len(embed)))
    def test_train(self):
        # Check that the model can overfit the toy data
        embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env()
        model = LSTMSeq2SeqModel.build_model(
            src_embed=embed, tgt_embed=None,
            num_layers=1, hidden_size=20, dropout=0.1,
            bind_encoder_decoder_embed=True,
            bind_decoder_input_output_embed=True)

        right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len)
        self.assertEqual(right_count, tgt_words_idx.nelement())
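
# Optional entry point (an assumption, not required by the test suite): lets the file
# be run directly with `python test_seq2seq_model.py`; under pytest/unittest discovery
# this guard is never needed.
if __name__ == '__main__':
    unittest.main()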