diff --git a/model_zoo/official/nlp/transformer/src/beam_search.py b/model_zoo/official/nlp/transformer/src/beam_search.py
index ffa3dfbac3..1157cb259a 100644
--- a/model_zoo/official/nlp/transformer/src/beam_search.py
+++ b/model_zoo/official/nlp/transformer/src/beam_search.py
@@ -168,6 +168,11 @@ class BeamSearchDecoder(nn.Cell):
         self.concat = P.Concat(axis=-1)
         self.gather_nd = P.GatherNd()

+        self.greater_equal = P.GreaterEqual()
+        self.sub = P.Sub()
+        self.cast = P.Cast()
+        self.zeroslike = P.ZerosLike()
+
         # init inputs and states
         self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
         self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
@@ -199,8 +204,19 @@ class BeamSearchDecoder(nn.Cell):
         topk_scores, topk_indices = self.topk(flat_scores, self.beam_width)

         # convert to beam and word indices
-        beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
-        word_indices = self.mod(topk_indices, self.vocab_size_tensor)
+        # beam_indices = self.floor_div(topk_indices, self.vocab_size_tensor)
+        # word_indices = self.mod(topk_indices, self.vocab_size_tensor)
+        # =====================================================================
+        # Replace the floor_div and mod ops, since these two ops only support
+        # fp16 on Ascend310, which will cause overflow.
+        temp = topk_indices
+        beam_indices = self.zeroslike(topk_indices)
+        for _ in range(self.beam_width - 1):
+            temp = self.sub(temp, self.vocab_size_tensor)
+            res = self.cast(self.greater_equal(temp, 0), mstype.int32)
+            beam_indices = beam_indices + res
+        word_indices = topk_indices - beam_indices * self.vocab_size_tensor
+        # =====================================================================

         # mask finished indices
         beam_indices = self.select(state_finished, self.beam_ids, beam_indices)
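
Note on the replacement (not part of the patch): as long as every entry of topk_indices lies in [0, beam_width * vocab_size), counting how many of the first beam_width - 1 subtractions of vocab_size leave a non-negative value reproduces the floor quotient, and the remainder then follows from one multiply and subtract. A minimal NumPy sketch of that equivalence, with made-up beam_width/vocab_size values used purely for illustration (not the Transformer config values):

    import numpy as np

    # Illustrative sizes only -- not the values used by the real config.
    beam_width, vocab_size = 4, 7
    # Flat top-k indices, each in [0, beam_width * vocab_size).
    topk_indices = np.array([0, 6, 7, 13, 20, 27], dtype=np.int32)

    # Reference result using integer floor division and modulo.
    ref_beam = topk_indices // vocab_size
    ref_word = topk_indices % vocab_size

    # Same result via the subtract-and-compare loop used in the patch.
    temp = topk_indices.copy()
    beam_indices = np.zeros_like(topk_indices)
    for _ in range(beam_width - 1):
        temp = temp - vocab_size
        beam_indices += (temp >= 0).astype(np.int32)
    word_indices = topk_indices - beam_indices * vocab_size

    assert np.array_equal(beam_indices, ref_beam)
    assert np.array_equal(word_indices, ref_word)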