You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Translator.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. ''' This module will handle the text generation with beam search. '''
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from transformer.Models import Transformer
  6. from transformer.Beam import Beam
  7. class Translator(object):
  8. ''' Load with trained model and handle the beam search '''
  9. def __init__(self, opt):
  10. self.opt = opt
  11. self.device = torch.device('cuda' if opt.cuda else 'cpu')
  12. checkpoint = torch.load(opt.model)
  13. model_opt = checkpoint['settings']
  14. self.model_opt = model_opt
  15. model = Transformer(
  16. model_opt.src_vocab_size,
  17. model_opt.tgt_vocab_size,
  18. model_opt.max_token_seq_len,
  19. tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
  20. emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
  21. d_k=model_opt.d_k,
  22. d_v=model_opt.d_v,
  23. d_model=model_opt.d_model,
  24. d_word_vec=model_opt.d_word_vec,
  25. d_inner=model_opt.d_inner_hid,
  26. n_layers=model_opt.n_layers,
  27. n_head=model_opt.n_head,
  28. dropout=model_opt.dropout)
  29. model.load_state_dict(checkpoint['model'])
  30. print('[Info] Trained model state loaded.')
  31. model.word_prob_prj = nn.LogSoftmax(dim=1)
  32. model = model.to(self.device)
  33. self.model = model
  34. self.model.eval()
  35. def translate_batch(self, src_seq, src_pos):
  36. ''' Translation work in one batch '''
  37. def get_inst_idx_to_tensor_position_map(inst_idx_list):
  38. ''' Indicate the position of an instance in a tensor. '''
  39. return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
  40. def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
  41. ''' Collect tensor parts associated to active instances. '''
  42. _, *d_hs = beamed_tensor.size()
  43. n_curr_active_inst = len(curr_active_inst_idx)
  44. new_shape = (n_curr_active_inst * n_bm, *d_hs)
  45. beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
  46. beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
  47. beamed_tensor = beamed_tensor.view(*new_shape)
  48. return beamed_tensor
  49. def collate_active_info(
  50. src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
  51. # Sentences which are still active are collected,
  52. # so the decoder will not run on completed sentences.
  53. n_prev_active_inst = len(inst_idx_to_position_map)
  54. active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
  55. active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)
  56. active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
  57. active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
  58. active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
  59. return active_src_seq, active_src_enc, active_inst_idx_to_position_map
  60. def beam_decode_step(
  61. inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm):
  62. ''' Decode and update beam status, and then return active beam idx '''
  63. def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
  64. dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
  65. dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
  66. dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
  67. return dec_partial_seq
  68. def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
  69. dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
  70. dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
  71. return dec_partial_pos
  72. def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
  73. dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
  74. dec_output = dec_output[:, -1, :] # Pick the last step: (bh * bm) * d_h
  75. word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
  76. word_prob = word_prob.view(n_active_inst, n_bm, -1)
  77. return word_prob
  78. def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
  79. active_inst_idx_list = []
  80. for inst_idx, inst_position in inst_idx_to_position_map.items():
  81. is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
  82. if not is_inst_complete:
  83. active_inst_idx_list += [inst_idx]
  84. return active_inst_idx_list
  85. n_active_inst = len(inst_idx_to_position_map)
  86. dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
  87. dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
  88. word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)
  89. # Update the beam with predicted word prob information and collect incomplete instances
  90. active_inst_idx_list = collect_active_inst_idx_list(
  91. inst_dec_beams, word_prob, inst_idx_to_position_map)
  92. return active_inst_idx_list
  93. def collect_hypothesis_and_scores(inst_dec_beams, n_best):
  94. all_hyp, all_scores = [], []
  95. for inst_idx in range(len(inst_dec_beams)):
  96. scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
  97. all_scores += [scores[:n_best]]
  98. hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
  99. all_hyp += [hyps]
  100. return all_hyp, all_scores
  101. with torch.no_grad():
  102. #-- Encode
  103. src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
  104. src_enc, *_ = self.model.encoder(src_seq, src_pos)
  105. #-- Repeat data for beam search
  106. n_bm = self.opt.beam_size
  107. n_inst, len_s, d_h = src_enc.size()
  108. src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
  109. src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
  110. #-- Prepare beams
  111. inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]
  112. #-- Bookkeeping for active or not
  113. active_inst_idx_list = list(range(n_inst))
  114. inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
  115. #-- Decode
  116. for len_dec_seq in range(1, self.model_opt.max_token_seq_len + 1):
  117. active_inst_idx_list = beam_decode_step(
  118. inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)
  119. if not active_inst_idx_list:
  120. break # all instances have finished their path to <EOS>
  121. src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
  122. src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)
  123. batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.opt.n_best)
  124. return batch_hyp, batch_scores