@@ -224,6 +224,11 @@ class BeamSearchDecoder(nn.Cell):
         self.one = Tensor(1, mstype.int32)
         self.prob_concat = P.Concat(axis=1)
+        self.greater_equal = P.GreaterEqual()
+        self.sub = P.Sub()
+        self.cast = P.Cast()
+        self.zeroslike = P.ZerosLike()
+
     def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_seq, state_finished,
                  state_length, entire_log_probs):
         """
@@ -261,8 +266,19 @@ class BeamSearchDecoder(nn.Cell):
         topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)
         # convert to beam and word indices, [batch, beam]
-        beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
-        word_indices = self.mod(topk_indices, self.vocab_size_tensor)
+        # beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
+        # word_indices = self.mod(topk_indices, self.vocab_size_tensor)
+        # ======================================================================
+        # Workaround: FloorDiv and Mod only support fp16 on Ascend 310, which
+        # overflows for large index values. Instead, compute the beam index by
+        # counting how many times vocab_size can be subtracted from each index.
+        temp = topk_indices
+        beam_indices = self.zeroslike(topk_indices)
+        for _ in range(self.beam_width - 1):
+            temp = self.sub(temp, self.vocab_size_tensor)
+            res = self.cast(self.greater_equal(temp, 0), mstype.int32)
+            beam_indices = beam_indices + res
+        word_indices = topk_indices - beam_indices * self.vocab_size_tensor
+        # ======================================================================
         current_word_pro = self.gather_nd(
             log_probs,
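
Note: for reference, a minimal standalone sketch (plain NumPy, with hypothetical beam_width and vocab_size values chosen only for illustration) checking that the subtract-and-count loop in this patch reproduces floor_div and mod for indices in [0, beam_width * vocab_size), which is exactly the range TopK can return here:

import numpy as np

# Hypothetical sizes for illustration; the real values come from the model config.
beam_width, vocab_size = 4, 30000
topk_indices = np.random.randint(0, beam_width * vocab_size, size=(2, beam_width))

# After k subtractions of vocab_size, an index is still >= 0 iff it was at
# least k * vocab_size, so summing the comparison results over
# beam_width - 1 rounds yields floor(index / vocab_size).
temp = topk_indices.copy()
beam_indices = np.zeros_like(topk_indices)
for _ in range(beam_width - 1):
    temp = temp - vocab_size
    beam_indices += (temp >= 0).astype(np.int32)
word_indices = topk_indices - beam_indices * vocab_size

assert (beam_indices == topk_indices // vocab_size).all()
assert (word_indices == topk_indices % vocab_size).all()

beam_width - 1 iterations suffice because the beam index is at most beam_width - 1, so the loop count is bounded regardless of vocabulary size.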