|
|
|
@@ -781,95 +781,22 @@ class TransformerDecoder(nn.Cell): |
|
|
|
super(TransformerDecoder, self).__init__() |
|
|
|
self.num_hidden_layers = num_hidden_layers |
|
|
|
|
|
|
|
# wait to be supported |
|
|
|
# layers = [] |
|
|
|
# for _ in range(num_hidden_layers): |
|
|
|
# layer = DecoderCell(batch_size=batch_size, |
|
|
|
# hidden_size=hidden_size, |
|
|
|
# seq_length=seq_length, |
|
|
|
# enc_seq_length=enc_seq_length, |
|
|
|
# num_attention_heads=num_attention_heads, |
|
|
|
# intermediate_size=intermediate_size, |
|
|
|
# attention_probs_dropout_prob=attention_probs_dropout_prob, |
|
|
|
# use_one_hot_embeddings=use_one_hot_embeddings, |
|
|
|
# initializer_range=initializer_range, |
|
|
|
# hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
# hidden_act=hidden_act, |
|
|
|
# compute_type=compute_type) |
|
|
|
# layers.append(layer) |
|
|
|
# self.layers = nn.CellList(layers) |
|
|
|
self.layer0 = DecoderCell(batch_size=batch_size, |
|
|
|
hidden_size=hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
enc_seq_length=enc_seq_length, |
|
|
|
num_attention_heads=num_attention_heads, |
|
|
|
intermediate_size=intermediate_size, |
|
|
|
attention_probs_dropout_prob=attention_probs_dropout_prob, |
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings, |
|
|
|
initializer_range=initializer_range, |
|
|
|
hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
hidden_act=hidden_act, |
|
|
|
compute_type=compute_type) |
|
|
|
self.layer1 = DecoderCell(batch_size=batch_size, |
|
|
|
hidden_size=hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
enc_seq_length=enc_seq_length, |
|
|
|
num_attention_heads=num_attention_heads, |
|
|
|
intermediate_size=intermediate_size, |
|
|
|
attention_probs_dropout_prob=attention_probs_dropout_prob, |
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings, |
|
|
|
initializer_range=initializer_range, |
|
|
|
hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
hidden_act=hidden_act, |
|
|
|
compute_type=compute_type) |
|
|
|
self.layer2 = DecoderCell(batch_size=batch_size, |
|
|
|
hidden_size=hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
enc_seq_length=enc_seq_length, |
|
|
|
num_attention_heads=num_attention_heads, |
|
|
|
intermediate_size=intermediate_size, |
|
|
|
attention_probs_dropout_prob=attention_probs_dropout_prob, |
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings, |
|
|
|
initializer_range=initializer_range, |
|
|
|
hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
hidden_act=hidden_act, |
|
|
|
compute_type=compute_type) |
|
|
|
self.layer3 = DecoderCell(batch_size=batch_size, |
|
|
|
hidden_size=hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
enc_seq_length=enc_seq_length, |
|
|
|
num_attention_heads=num_attention_heads, |
|
|
|
intermediate_size=intermediate_size, |
|
|
|
attention_probs_dropout_prob=attention_probs_dropout_prob, |
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings, |
|
|
|
initializer_range=initializer_range, |
|
|
|
hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
hidden_act=hidden_act, |
|
|
|
compute_type=compute_type) |
|
|
|
self.layer4 = DecoderCell(batch_size=batch_size, |
|
|
|
hidden_size=hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
enc_seq_length=enc_seq_length, |
|
|
|
num_attention_heads=num_attention_heads, |
|
|
|
intermediate_size=intermediate_size, |
|
|
|
attention_probs_dropout_prob=attention_probs_dropout_prob, |
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings, |
|
|
|
initializer_range=initializer_range, |
|
|
|
hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
hidden_act=hidden_act, |
|
|
|
compute_type=compute_type) |
|
|
|
self.layer5 = DecoderCell(batch_size=batch_size, |
|
|
|
hidden_size=hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
enc_seq_length=enc_seq_length, |
|
|
|
num_attention_heads=num_attention_heads, |
|
|
|
intermediate_size=intermediate_size, |
|
|
|
attention_probs_dropout_prob=attention_probs_dropout_prob, |
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings, |
|
|
|
initializer_range=initializer_range, |
|
|
|
hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
hidden_act=hidden_act, |
|
|
|
compute_type=compute_type) |
|
|
|
layers = [] |
|
|
|
for _ in range(num_hidden_layers): |
|
|
|
layer = DecoderCell(batch_size=batch_size, |
|
|
|
hidden_size=hidden_size, |
|
|
|
seq_length=seq_length, |
|
|
|
enc_seq_length=enc_seq_length, |
|
|
|
num_attention_heads=num_attention_heads, |
|
|
|
intermediate_size=intermediate_size, |
|
|
|
attention_probs_dropout_prob=attention_probs_dropout_prob, |
|
|
|
use_one_hot_embeddings=use_one_hot_embeddings, |
|
|
|
initializer_range=initializer_range, |
|
|
|
hidden_dropout_prob=hidden_dropout_prob, |
|
|
|
hidden_act=hidden_act, |
|
|
|
compute_type=compute_type) |
|
|
|
layers.append(layer) |
|
|
|
self.layers = nn.CellList(layers) |
|
|
|
|
|
|
|
self.layer_preprocess = LayerPreprocess(in_channels=hidden_size) |
|
|
|
|
|
|
|
@@ -880,16 +807,9 @@ class TransformerDecoder(nn.Cell): |
|
|
|
def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask): |
|
|
|
prev_output = self.reshape(input_tensor, self.shape) |
|
|
|
|
|
|
|
# wait to be supported |
|
|
|
# for layer_module in self.layers: |
|
|
|
# layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask) |
|
|
|
# prev_output = layer_output |
|
|
|
prev_output = self.layer0(prev_output, attention_mask, enc_states, enc_attention_mask) |
|
|
|
prev_output = self.layer1(prev_output, attention_mask, enc_states, enc_attention_mask) |
|
|
|
prev_output = self.layer2(prev_output, attention_mask, enc_states, enc_attention_mask) |
|
|
|
prev_output = self.layer3(prev_output, attention_mask, enc_states, enc_attention_mask) |
|
|
|
prev_output = self.layer4(prev_output, attention_mask, enc_states, enc_attention_mask) |
|
|
|
prev_output = self.layer5(prev_output, attention_mask, enc_states, enc_attention_mask) |
|
|
|
for layer_module in self.layers: |
|
|
|
layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask) |
|
|
|
prev_output = layer_output |
|
|
|
|
|
|
|
prev_output = self.layer_preprocess(prev_output) |
|
|
|
output = self.reshape(prev_output, self.out_shape) |
|
|
|
|