| @@ -168,6 +168,11 @@ class BeamSearchDecoder(nn.Cell): | |||||
| self.concat = P.Concat(axis=-1) | self.concat = P.Concat(axis=-1) | ||||
| self.gather_nd = P.GatherNd() | self.gather_nd = P.GatherNd() | ||||
| self.greater_equal = P.GreaterEqual() | |||||
| self.sub = P.Sub() | |||||
| self.cast = P.Cast() | |||||
| self.zeroslike = P.ZerosLike() | |||||
| # init inputs and states | # init inputs and states | ||||
| self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) | self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) | ||||
| self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) | self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) | ||||
| @@ -199,8 +204,19 @@ class BeamSearchDecoder(nn.Cell): | |||||
| topk_scores, topk_indices = self.topk(flat_scores, self.beam_width) | topk_scores, topk_indices = self.topk(flat_scores, self.beam_width) | ||||
| # convert to beam and word indices | # convert to beam and word indices | ||||
| beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor) | |||||
| word_indices = self.mod(topk_indices, self.vocab_size_tensor) | |||||
| #beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor) | |||||
| #word_indices = self.mod(topk_indices, self.vocab_size_tensor) | |||||
| #====================================================================== | |||||
| #replace floor_div and mod op, since these two ops only support fp16 on | |||||
| #Ascend310, which will cause overflow. | |||||
| temp = topk_indices | |||||
| beam_indices = self.zeroslike(topk_indices) | |||||
| for _ in range(self.beam_width - 1): | |||||
| temp = self.sub(temp, self.vocab_size_tensor) | |||||
| res = self.cast(self.greater_equal(temp, 0), mstype.int32) | |||||
| beam_indices = beam_indices + res | |||||
| word_indices = topk_indices - beam_indices * self.vocab_size_tensor | |||||
| #====================================================================== | |||||
| # mask finished indices | # mask finished indices | ||||
| beam_indices = self.select(state_finished, self.beam_ids, beam_indices) | beam_indices = self.select(state_finished, self.beam_ids, beam_indices) | ||||