@@ -172,6 +172,7 @@ class BeamSearchDecoder(nn.Cell):
                  max_decode_length=64,
                  sos_id=2,
                  eos_id=3,
+                 is_using_while=False,
                  compute_type=mstype.float32):
         super(BeamSearchDecoder, self).__init__()
@@ -185,6 +186,7 @@ class BeamSearchDecoder(nn.Cell):
         self.cov_penalty_factor = cov_penalty_factor
         self.max_decode_length = max_decode_length
         self.decoder = decoder
+        self.is_using_while = is_using_while
         self.add = P.TensorAdd()
         self.expand = P.ExpandDims()
@@ -215,7 +217,12 @@ class BeamSearchDecoder(nn.Cell):
         self.gather_nd = P.GatherNd()

         self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
-        self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
+        if self.is_using_while:
+            self.start = Tensor(0, dtype=mstype.int32)
+            self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id),
+                                   mstype.int32)
+        else:
+            self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)

         init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
         self.init_scores = Tensor(init_scores, mstype.float32)
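Note: with is_using_while=True the decoder preallocates init_seq to the full max_decode_length and creates a Tensor step counter self.start, so the decode loop can write each new token in place instead of growing the sequence by concatenation; a sketch of the two update strategies follows the state_seq hunk below.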
@@ -259,7 +266,7 @@ class BeamSearchDecoder(nn.Cell):
         self.sub = P.Sub()

     def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
-                 state_seq, state_length, decoder_hidden_state=None, accu_attn_scores=None,
+                 state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None,
                  state_finished=None):
         """
         Beam search one_step output.
@@ -359,7 +366,13 @@ class BeamSearchDecoder(nn.Cell):
                                                              self.hidden_size))

         # update state_seq
-        state_seq = self.concat((seq, self.expand(word_indices, -1)))
+        if self.is_using_while:
+            state_seq_new = self.cast(seq, mstype.float32)
+            word_indices_fp32 = self.cast(word_indices, mstype.float32)
+            state_seq_new[:, :, idx] = word_indices_fp32
+            state_seq = self.cast(state_seq_new, mstype.int32)
+        else:
+            state_seq = self.concat((seq, self.expand(word_indices, -1)))

         cur_input_ids = self.reshape(word_indices, (-1, 1))
         state_log_probs = topk_scores
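This hunk is the heart of the change. Below is a minimal NumPy sketch of the two state_seq update strategies, with made-up sizes; the float32 round-trip in the real code appears to be a workaround for tensor item assignment and is kept out of the sketch:

import numpy as np

batch_size, beam_width, max_decode_length, sos_id = 2, 4, 8, 2
word_indices = np.random.randint(0, 100, (batch_size, beam_width))

# Unrolled variant: the kept sequences grow by one column per step,
# so every iteration produces a new shape.
seq = np.full((batch_size, beam_width, 1), sos_id)        # starts as [sos]
seq = np.concatenate((seq, word_indices[..., None]), -1)  # length 2 now

# While-loop variant: a buffer preallocated to max_decode_length is
# written in place at the current step index, so shapes stay static,
# which is what a graph-mode while loop requires.
idx = 1
buf = np.full((batch_size, beam_width, max_decode_length), sos_id)
buf[:, :, idx] = word_indices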
@@ -388,11 +401,22 @@ class BeamSearchDecoder(nn.Cell):
         decoder_hidden_state = self.decoder_hidden_state
         accu_attn_scores = self.accu_attn_scores

-        for _ in range(self.max_decode_length + 1):
-            cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
-            state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
-                                           state_seq, state_length, decoder_hidden_state, accu_attn_scores,
-                                           state_finished)
+        if not self.is_using_while:
+            for _ in range(self.max_decode_length + 1):
+                cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
+                state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
+                                               state_seq, state_length, None, decoder_hidden_state, accu_attn_scores,
+                                               state_finished)
+        else:
+            idx = self.start + 1
+            ends = self.start + self.max_decode_length + 1
+            while idx < ends:
+                cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
+                state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
+                                               state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores,
+                                               state_finished)
+                idx = idx + 1

         # add length penalty scores
         penalty_len = self.length_penalty(state_length)
         # return penalty_len
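Note: the old for loop is unrolled at graph-compile time into max_decode_length + 1 copies of one_step. In the new branch the loop bounds are derived from the Tensor self.start, so MindSpore can compile the while as a single graph-mode loop rather than unrolling it; the step index idx is threaded through to the in-place sequence update above, and the unrolled path now passes None in its place.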
@@ -408,6 +432,9 @@ class BeamSearchDecoder(nn.Cell):
         gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
         # sort sequence and attention scores
         predicted_ids = self.gather_nd(state_seq, gather_indices)
-        predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
+        if not self.is_using_while:
+            predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
+        else:
+            predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length]

         return predicted_ids
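The final slice differs per mode because the concat variant's grown sequence still carries the start token at position 0, while the while variant returns its fixed-length buffer as-is. Choosing between the two is a constructor-time decision; a hypothetical call, spelling out only the arguments visible in this diff (the rest are elided):

# Hypothetical usage; '...' stands for the constructor arguments not shown in this diff.
decoder_unrolled = BeamSearchDecoder(..., max_decode_length=64, sos_id=2, eos_id=3,
                                     is_using_while=False)
decoder_while = BeamSearchDecoder(..., max_decode_length=64, sos_id=2, eos_id=3,
                                  is_using_while=True)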