@@ -224,6 +224,11 @@ class BeamSearchDecoder(nn.Cell):
         self.one = Tensor(1, mstype.int32)
         self.prob_concat = P.Concat(axis=1)
+        self.greater_equal = P.GreaterEqual()
+        self.sub = P.Sub()
+        self.cast = P.Cast()
+        self.zeroslike = P.ZerosLike()
 
     def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_finished,
                  state_length, entire_log_probs):
         """
@@ -261,8 +266,19 @@ class BeamSearchDecoder(nn.Cell):
         topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
 
         # convert to beam and word indices, [batch, beam]
-        beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
-        word_indices = self.mod(topk_indices, self.vocab_size_tensor)
+        # beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
+        # word_indices = self.mod(topk_indices, self.vocab_size_tensor)
+        # ======================================================================
+        # Replace the floor_div and mod ops: on Ascend310 they only support
+        # fp16, which will cause overflow for these indices.
+        temp = topk_indices
+        beam_indices = self.zeroslike(topk_indices)
+        for _ in range(self.beam_width - 1):
+            temp = self.sub(temp, self.vocab_size_tensor)
+            res = self.cast(self.greater_equal(temp, 0), mstype.int32)
+            beam_indices = beam_indices + res
+        word_indices = topk_indices - beam_indices * self.vocab_size_tensor
+        # ======================================================================
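
The added loop is integer division by repeated subtraction: after k passes,
temp == topk_indices - k * vocab_size, so counting how many of the first
beam_width - 1 subtractions leave temp non-negative yields
topk_indices // vocab_size for any index in [0, beam_width * vocab_size),
entirely in int32. (fp16 cannot handle these flattened indices, which reach
beam_width * vocab_size: integers are exact in fp16 only up to 2048, and
values above 65504 overflow.) A minimal NumPy sketch, not part of the patch
and with illustrative beam_width and vocab_size values, that checks the loop
against floor_div and mod:

    import numpy as np

    beam_width, vocab_size = 4, 30000  # assumed sizes for illustration
    topk_indices = np.random.randint(0, beam_width * vocab_size,
                                     size=(2, beam_width), dtype=np.int32)

    temp = topk_indices
    beam_indices = np.zeros_like(topk_indices)
    for _ in range(beam_width - 1):
        temp = temp - vocab_size
        # one more whole multiple of vocab_size fits into each index still >= 0
        beam_indices = beam_indices + (temp >= 0).astype(np.int32)
    word_indices = topk_indices - beam_indices * vocab_size

    assert (beam_indices == topk_indices // vocab_size).all()
    assert (word_indices == topk_indices % vocab_size).all()
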
 
         current_word_pro = self.gather_nd(
             log_probs,