
test_seq2seq_generator.py

import unittest

import torch

from fastNLP import Vocabulary, DataSet
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
from fastNLP import Callback
from fastNLP.embeddings import StaticEmbedding
from fastNLP.models import SequenceGeneratorModel
from fastNLP.models import LSTMSeq2SeqModel, TransformerSeq2SeqModel

def prepare_env():
    # A tiny vocabulary with a randomly initialised 5-dim embedding: just
    # enough for a smoke test.
    vocab = Vocabulary().add_word_lst("This is a test .".split())
    vocab.add_word_lst("Another test !".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)

    src_words_idx = [[3, 1, 2], [1, 2]]
    # tgt_words_idx = [[1, 2, 3, 4], [2, 3]]
    src_seq_len = [3, 2]
    # tgt_seq_len = [4, 2]

    # The source doubles as the target, so the models are fit on a trivial
    # copy task they should solve quickly.
    ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len,
                  'tgt_tokens': src_words_idx, 'tgt_seq_len': src_seq_len})
    ds.set_input('src_tokens', 'tgt_tokens', 'src_seq_len')
    ds.set_target('tgt_seq_len', 'tgt_tokens')
    return embed, ds

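# Note on the fastNLP (0.x) data protocol used above: fields marked with
# set_input() are batched and fed to the model's forward() by parameter-name
# matching, while fields marked with set_target() go to the loss and metrics.
# 'tgt_tokens' is marked as both because the decoder consumes it for teacher
# forcing and the loss compares predictions against it.
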
class ExitCallback(Callback):
    """Stop training as soon as the dev-set accuracy reaches 1.0."""

    def __init__(self):
        super().__init__()

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        if eval_result['AccuracyMetric']['acc'] == 1:
            raise KeyboardInterrupt()

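# Raising KeyboardInterrupt is the early-exit mechanism: the fastNLP 0.x
# Trainer catches it and stops training gracefully, so trainer.train() still
# returns the evaluation results gathered up to that point.
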
class TestSeq2SeqGeneratorModel(unittest.TestCase):
    def test_run(self):
        # Check that training works with SequenceGeneratorModel, with
        # prediction passed straight through.
        embed, ds = prepare_env()
        model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
                                                     pos_embed='sin', max_position=20,
                                                     num_layers=2, d_model=30, n_head=6,
                                                     dim_ff=20, dropout=0.1,
                                                     bind_encoder_decoder_embed=True,
                                                     bind_decoder_input_output_embed=True)
        trainer = Trainer(ds, model1, optimizer=None,
                          loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'),
                          batch_size=32, sampler=None, drop_last=False, update_every=1,
                          num_workers=0, n_epochs=100, print_every=5,
                          dev_data=ds,
                          metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'),
                          metric_key=None, validate_every=-1, save_path=None,
                          use_tqdm=False, device=None,
                          callbacks=ExitCallback(), check_code_level=0)
        res = trainer.train()
        self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1)
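        # The SequenceGeneratorModel imported above would be the natural wrapper
        # for decoding without teacher forcing. A minimal sketch, assuming the
        # fastNLP 0.6 signature (the bos/eos token ids below are illustrative):
        #
        #   gen_model = SequenceGeneratorModel(model1, bos_token_id=0, eos_token_id=1,
        #                                      max_length=10, num_beams=1, do_sample=False)
        #   pred = gen_model.predict(torch.LongTensor([[3, 1, 2]]),
        #                            src_seq_len=torch.LongTensor([3]))['pred']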
        # Same check with the LSTM-based seq2seq model, this time with an
        # explicit optimizer.
        embed, ds = prepare_env()
        model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None,
                                              num_layers=1, hidden_size=20, dropout=0.1,
                                              bind_encoder_decoder_embed=True,
                                              bind_decoder_input_output_embed=True,
                                              attention=True)
        optimizer = torch.optim.Adam(model2.parameters(), lr=0.01)
        trainer = Trainer(ds, model2, optimizer=optimizer,
                          loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'),
                          batch_size=32, sampler=None, drop_last=False, update_every=1,
                          num_workers=0, n_epochs=200, print_every=1,
                          dev_data=ds,
                          metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'),
                          metric_key=None, validate_every=-1, save_path=None,
                          use_tqdm=False, device=None,
                          callbacks=ExitCallback(), check_code_level=0)
        res = trainer.train()
        self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1)
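
# Standard unittest entry point so the file can also be run directly
# (python test_seq2seq_generator.py).
if __name__ == '__main__':
    unittest.main()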