|
|
|
@@ -219,7 +219,7 @@ class BeamSearchDecoder(nn.Cell): |
|
|
|
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) |
|
|
|
if self.is_using_while: |
|
|
|
self.start = Tensor(0, dtype=mstype.int32) |
|
|
|
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), |
|
|
|
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length + 1], sos_id), |
|
|
|
mstype.int32) |
|
|
|
else: |
|
|
|
self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32) |
|
|
|
@@ -402,7 +402,7 @@ class BeamSearchDecoder(nn.Cell): |
|
|
|
accu_attn_scores = self.accu_attn_scores |
|
|
|
|
|
|
|
if not self.is_using_while: |
|
|
|
for _ in range(self.max_decode_length + 1): |
|
|
|
for _ in range(self.max_decode_length): |
|
|
|
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ |
|
|
|
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, |
|
|
|
state_seq, state_length, None, decoder_hidden_state, accu_attn_scores, |
|
|
|
|