
Beam.py

  1. """ Manage beam search info structure.
  2. Heavily borrowed from OpenNMT-py.
  3. For code in OpenNMT-py, please check the following link:
  4. https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/Beam.py
  5. """
  6. import torch
  7. import numpy as np
  8. import transformer.Constants as Constants
  9. class Beam():
  10. ''' Beam search '''
  11. def __init__(self, size, device=False):
  12. self.size = size
  13. self._done = False
  14. # The score for each translation on the beam.
  15. self.scores = torch.zeros((size,), dtype=torch.float, device=device)
  16. self.all_scores = []
  17. # The backpointers at each time-step.
  18. self.prev_ks = []
  19. # The outputs at each time-step.
  20. self.next_ys = [torch.full((size,), Constants.PAD, dtype=torch.long, device=device)]
  21. self.next_ys[0][0] = Constants.BOS
  22. def get_current_state(self):
  23. "Get the outputs for the current timestep."
  24. return self.get_tentative_hypothesis()
  25. def get_current_origin(self):
  26. "Get the backpointers for the current timestep."
  27. return self.prev_ks[-1]
  28. @property
  29. def done(self):
  30. return self._done
  31. def advance(self, word_prob):
  32. "Update beam status and check if finished or not."
  33. num_words = word_prob.size(1)
  34. # Sum the previous scores.
  35. if len(self.prev_ks) > 0:
  36. beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
  37. else:
  38. beam_lk = word_prob[0]
  39. flat_beam_lk = beam_lk.view(-1)
  40. best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort
  41. best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort
  42. self.all_scores.append(self.scores)
  43. self.scores = best_scores
  44. # bestScoresId is flattened as a (beam x word) array,
  45. # so we need to calculate which word and beam each score came from
  46. prev_k = best_scores_id / num_words
  47. self.prev_ks.append(prev_k)
  48. self.next_ys.append(best_scores_id - prev_k * num_words)
  49. # End condition is when top-of-beam is EOS.
  50. if self.next_ys[-1][0].item() == Constants.EOS:
  51. self._done = True
  52. self.all_scores.append(self.scores)
  53. return self._done
  54. def sort_scores(self):
  55. "Sort the scores."
  56. return torch.sort(self.scores, 0, True)
  57. def get_the_best_score_and_idx(self):
  58. "Get the score of the best in the beam."
  59. scores, ids = self.sort_scores()
  60. return scores[1], ids[1]
  61. def get_tentative_hypothesis(self):
  62. "Get the decoded sequence for the current timestep."
  63. if len(self.next_ys) == 1:
  64. dec_seq = self.next_ys[0].unsqueeze(1)
  65. else:
  66. _, keys = self.sort_scores()
  67. hyps = [self.get_hypothesis(k) for k in keys]
  68. hyps = [[Constants.BOS] + h for h in hyps]
  69. dec_seq = torch.LongTensor(hyps)
  70. return dec_seq
  71. def get_hypothesis(self, k):
  72. """ Walk back to construct the full hypothesis. """
  73. hyp = []
  74. for j in range(len(self.prev_ks) - 1, -1, -1):
  75. hyp.append(self.next_ys[j+1][k])
  76. k = self.prev_ks[j][k]
  77. return list(map(lambda x: x.item(), hyp[::-1]))
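For quick sanity checking, a minimal driver like the sketch below can exercise the class end to end. The loop, the random log-probabilities, and the five-step cap are illustrative stand-ins: in the repo the per-step distributions come from the Transformer decoder, and the special token ids come from transformer.Constants (PAD=0, BOS=2, EOS=3 there, though treat those values as an assumption of this sketch).

if __name__ == '__main__':
    # Toy smoke test: drive the beam with random log-probabilities.
    # In real decoding, word_prob is the decoder's log-softmax output,
    # with one row per live hypothesis on the beam.
    torch.manual_seed(0)
    beam_size, vocab_size = 3, 10

    beam = Beam(beam_size, device='cpu')
    for _ in range(5):  # decode at most 5 steps
        word_prob = torch.log_softmax(
            torch.randn(beam_size, vocab_size), dim=1)
        if beam.advance(word_prob):
            break  # top-of-beam hypothesis ended with EOS

    best_score, best_id = beam.get_the_best_score_and_idx()
    print('best score:', best_score.item())
    print('best hypothesis token ids:', beam.get_hypothesis(best_id))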