@@ -23,6 +23,7 @@ import mindspore.ops.functional as F
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
+from mindspore.ops.primitive import constexpr
 from .beam_search import BeamSearchDecoder, TileBeam
 from .weight_init import normal_weight, weight_variable
 
@@ -296,8 +297,6 @@ class MultiheadAttention(nn.Cell):
                  from_tensor_width,
                  to_tensor_width,
                  out_tensor_width,
-                 from_seq_length,
-                 to_seq_length,
                  num_attention_heads=1,
                  size_per_head=512,
                  query_act=None,
@@ -312,12 +311,13 @@ class MultiheadAttention(nn.Cell):
                  compute_type=mstype.float32):
         super(MultiheadAttention, self).__init__()
+        self.batch_size = batch_size
-        self.from_seq_length = from_seq_length
-        self.to_seq_length = to_seq_length
         self.num_attention_heads = num_attention_heads
         self.size_per_head = size_per_head
         self.has_attention_mask = has_attention_mask
+        assert has_attention_mask
         self.use_one_hot_embeddings = use_one_hot_embeddings
         self.initializer_range = initializer_range
         self.do_return_2d_tensor = do_return_2d_tensor
+
         self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
         self.reshape = P.Reshape()
@@ -345,9 +345,6 @@ class MultiheadAttention(nn.Cell):
                                   has_bias=False,
                                   weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)
 
-        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
-        self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
-
         self.matmul_trans_b = P.BatchMatMul(transpose_b=True)
         self.multiply = P.Mul()
         self.transpose = P.Transpose()
@@ -368,27 +365,33 @@ class MultiheadAttention(nn.Cell):
         self.add = P.TensorAdd()
         self.cast = P.Cast()
         self.get_dtype = P.DType()
-        if do_return_2d_tensor:
-            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
-            if from_seq_length == -1:
-                self.shape_return = (-1, num_attention_heads * size_per_head)
-        else:
-            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)
 
         self.cast_compute_type = CastWrapper(dst_type=compute_type)
         self.softmax_cast = P.Cast()
 
-    def construct(self, from_tensor, to_tensor, attention_mask=None):
-        """reshape 2d/3d input tensors to 2d"""
+    def construct(self, from_tensor, to_tensor, seq_length, enc_seq_length, attention_mask=None):
+        """Apply multihead attention."""
+        from_seq_length = seq_length
+        to_seq_length = enc_seq_length
+        shape_from = (self.batch_size, from_seq_length, self.num_attention_heads, self.size_per_head)
+        shape_to = (self.batch_size, to_seq_length, self.num_attention_heads, self.size_per_head)
+        if self.do_return_2d_tensor:
+            shape_return = (self.batch_size * from_seq_length, self.num_attention_heads * self.size_per_head)
+            if from_seq_length == -1:
+                shape_return = (-1, self.num_attention_heads * self.size_per_head)
+        else:
+            shape_return = (self.batch_size, from_seq_length, self.num_attention_heads * self.size_per_head)
+
+        # reshape 2d/3d input tensors to 2d
         from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
         to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
         query_out = self.query_layer(from_tensor_2d)
         key_out = self.key_layer(to_tensor_2d)
         value_out = self.value_layer(to_tensor_2d)
 
-        query_layer = self.reshape(query_out, self.shape_from)
+        query_layer = self.reshape(query_out, shape_from)
         query_layer = self.transpose(query_layer, self.trans_shape)
-        key_layer = self.reshape(key_out, self.shape_to)
+        key_layer = self.reshape(key_out, shape_to)
         key_layer = self.transpose(key_layer, self.trans_shape)
 
         attention_scores = self.matmul_trans_b(query_layer, key_layer)
@@ -407,12 +410,12 @@ class MultiheadAttention(nn.Cell):
         if self.use_dropout:
             attention_probs = self.dropout(attention_probs)
 
-        value_layer = self.reshape(value_out, self.shape_to)
+        value_layer = self.reshape(value_out, shape_to)
         value_layer = self.transpose(value_layer, self.trans_shape)
         context_layer = self.matmul(attention_probs, value_layer)
 
         context_layer = self.transpose(context_layer, self.trans_shape)
-        context_layer = self.reshape(context_layer, self.shape_return)
+        context_layer = self.reshape(context_layer, shape_return)
         context_layer = self.out_layer(context_layer)
         return context_layer
 
@@ -438,8 +441,6 @@ class SelfAttention(nn.Cell):
     """
     def __init__(self,
                  batch_size,
-                 from_seq_length,
-                 to_seq_length,
                  hidden_size,
                  num_attention_heads=16,
                  attention_probs_dropout_prob=0.1,
@@ -461,8 +462,6 @@ class SelfAttention(nn.Cell):
             from_tensor_width=hidden_size,
             to_tensor_width=hidden_size,
             out_tensor_width=hidden_size,
-            from_seq_length=from_seq_length,
-            to_seq_length=to_seq_length,
             num_attention_heads=num_attention_heads,
             size_per_head=self.size_per_head,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
@@ -477,7 +476,7 @@ class SelfAttention(nn.Cell):
 
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-    def construct(self, input_tensor, memory_tensor, attention_mask):
+    def construct(self, input_tensor, memory_tensor, attention_mask, seq_length, enc_seq_length):
         """Apply self-attention."""
         input_tensor = self.reshape(input_tensor, self.shape)
         memory_tensor = self.reshape(memory_tensor, self.shape)
@@ -487,7 +486,7 @@ class SelfAttention(nn.Cell):
         if not self.is_encdec_att:
             memory_tensor = output
 
-        attention_output = self.attention(output, memory_tensor, attention_mask)
+        attention_output = self.attention(output, memory_tensor, seq_length, enc_seq_length, attention_mask)
         output = self.postprocess(attention_output, input_tensor)
         return output
 
@@ -563,7 +562,6 @@ class EncoderCell(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size=1024,
-                 seq_length=128,
                  num_attention_heads=16,
                  intermediate_size=4096,
                  attention_probs_dropout_prob=0.1,
@@ -576,8 +574,6 @@ class EncoderCell(nn.Cell):
         self.attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,
@@ -594,9 +590,9 @@ class EncoderCell(nn.Cell):
             hidden_dropout_prob=hidden_dropout_prob,
             compute_type=compute_type)
 
-    def construct(self, hidden_states, attention_mask):
+    def construct(self, hidden_states, attention_mask, seq_length):
         # self-attention with ln, res
-        attention_output = self.attention(hidden_states, hidden_states, attention_mask)
+        attention_output = self.attention(hidden_states, hidden_states, attention_mask, seq_length, seq_length)
         # feed forward with ln, res
         output = self.feedforward(attention_output)
         return output
@@ -624,7 +620,6 @@ class TransformerEncoder(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size,
-                 seq_length,
                  num_hidden_layers,
                  num_attention_heads=16,
                  intermediate_size=4096,
@@ -636,12 +631,13 @@ class TransformerEncoder(nn.Cell):
                  compute_type=mstype.float32):
         super(TransformerEncoder, self).__init__()
         self.num_hidden_layers = num_hidden_layers
+        self.batch_size = batch_size
+        self.hidden_size = hidden_size
 
         layers = []
         for _ in range(num_hidden_layers):
             layer = EncoderCell(batch_size=batch_size,
                                 hidden_size=hidden_size,
-                                seq_length=seq_length,
                                 num_attention_heads=num_attention_heads,
                                 intermediate_size=intermediate_size,
                                 attention_probs_dropout_prob=attention_probs_dropout_prob,
@@ -657,17 +653,18 @@ class TransformerEncoder(nn.Cell):
 
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-        self.out_shape = (batch_size, seq_length, hidden_size)
 
-    def construct(self, input_tensor, attention_mask):
+    def construct(self, input_tensor, attention_mask, seq_length):
+        """Apply encoder."""
+        out_shape = (self.batch_size, seq_length, self.hidden_size)
         prev_output = self.reshape(input_tensor, self.shape)
 
         for layer_module in self.layers:
-            layer_output = layer_module(prev_output, attention_mask)
+            layer_output = layer_module(prev_output, attention_mask, seq_length)
             prev_output = layer_output
 
         prev_output = self.layer_preprocess(prev_output)
-        output = self.reshape(prev_output, self.out_shape)
+        output = self.reshape(prev_output, out_shape)
         return output
 
 
@@ -693,8 +690,6 @@ class DecoderCell(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size=1024,
-                 seq_length=128,
-                 enc_seq_length=128,
                  num_attention_heads=12,
                  intermediate_size=4096,
                  attention_probs_dropout_prob=0.02,
@@ -707,8 +702,6 @@ class DecoderCell(nn.Cell):
         self.self_attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,
@@ -719,8 +712,6 @@ class DecoderCell(nn.Cell):
         self.cross_attention = SelfAttention(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            from_seq_length=seq_length,
-            to_seq_length=enc_seq_length,
             num_attention_heads=num_attention_heads,
             attention_probs_dropout_prob=attention_probs_dropout_prob,
             use_one_hot_embeddings=use_one_hot_embeddings,
@@ -737,11 +728,12 @@ class DecoderCell(nn.Cell):
             hidden_dropout_prob=hidden_dropout_prob,
             compute_type=compute_type)
 
-    def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask):
+    def construct(self, hidden_states, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
         # self-attention with ln, res
-        attention_output = self.self_attention(hidden_states, hidden_states, attention_mask)
+        attention_output = self.self_attention(hidden_states, hidden_states, attention_mask, seq_length, seq_length)
         # cross-attention with ln, res
-        attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask)
+        attention_output = self.cross_attention(attention_output, enc_states, enc_attention_mask,
+                                                seq_length, enc_seq_length)
         # feed forward with ln, res
         output = self.feedforward(attention_output)
         return output
@@ -770,8 +762,6 @@ class TransformerDecoder(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size,
-                 seq_length,
-                 enc_seq_length,
                  num_hidden_layers,
                  num_attention_heads=16,
                  intermediate_size=4096,
@@ -788,8 +778,6 @@ class TransformerDecoder(nn.Cell):
         for _ in range(num_hidden_layers):
             layer = DecoderCell(batch_size=batch_size,
                                 hidden_size=hidden_size,
-                                seq_length=seq_length,
-                                enc_seq_length=enc_seq_length,
                                 num_attention_heads=num_attention_heads,
                                 intermediate_size=intermediate_size,
                                 attention_probs_dropout_prob=attention_probs_dropout_prob,
@@ -805,17 +793,21 @@ class TransformerDecoder(nn.Cell):
 
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
-        self.out_shape = (batch_size, seq_length, hidden_size)
+        self.hidden_size = hidden_size
+        self.batch_size = batch_size
 
-    def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask):
+    def construct(self, input_tensor, attention_mask, enc_states, enc_attention_mask, seq_length, enc_seq_length):
+        """Apply decoder."""
+        out_shape = (self.batch_size, seq_length, self.hidden_size)
         prev_output = self.reshape(input_tensor, self.shape)
 
         for layer_module in self.layers:
-            layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask)
+            layer_output = layer_module(prev_output, attention_mask, enc_states, enc_attention_mask,
+                                        seq_length, enc_seq_length)
             prev_output = layer_output
 
         prev_output = self.layer_preprocess(prev_output)
-        output = self.reshape(prev_output, self.out_shape)
+        output = self.reshape(prev_output, out_shape)
         return output
 
 
@@ -860,13 +852,11 @@ class PredLogProbs(nn.Cell):
     """
     def __init__(self,
                  batch_size,
-                 seq_length,
                  width,
                  compute_type=mstype.float32,
                  dtype=mstype.float32):
         super(PredLogProbs, self).__init__()
         self.batch_size = batch_size
-        self.seq_length = seq_length
         self.width = width
         self.compute_type = compute_type
         self.dtype = dtype
@@ -874,14 +864,16 @@ class PredLogProbs(nn.Cell):
         self.reshape = P.Reshape()
         self.matmul = P.MatMul(transpose_b=True)
         self.log_softmax = nn.LogSoftmax(axis=-1)
-        self.shape_flat_sequence_tensor = (self.batch_size * self.seq_length, self.width)
         self.cast = P.Cast()
 
     def construct(self,
                   input_tensor,
-                  output_weights):
+                  output_weights,
+                  seq_length):
         """Get log probs."""
-        input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
+        shape_flat_sequence_tensor = (self.batch_size * seq_length, self.width)
+
+        input_tensor = self.reshape(input_tensor, shape_flat_sequence_tensor)
         input_tensor = self.cast(input_tensor, self.compute_type)
         output_weights = self.cast(output_weights, self.compute_type)
 
@@ -918,7 +910,6 @@ class TransformerDecoderStep(nn.Cell):
     def __init__(self,
                  batch_size,
                  hidden_size,
-                 enc_seq_length,
                  max_decode_length,
                  num_hidden_layers,
                  num_attention_heads=16,
@@ -942,8 +933,6 @@ class TransformerDecoderStep(nn.Cell):
         self.tfm_decoder = TransformerDecoder(
             batch_size=batch_size,
             hidden_size=hidden_size,
-            seq_length=-1,  # -1 means length is not fixed
-            enc_seq_length=enc_seq_length,
             num_attention_heads=num_attention_heads,
             num_hidden_layers=num_hidden_layers,
             intermediate_size=intermediate_size,
@@ -966,7 +955,7 @@ class TransformerDecoderStep(nn.Cell):
 
         self.cast_compute_type = CastWrapper(dst_type=compute_type)
 
-    def construct(self, input_ids, enc_states, enc_attention_mask):
+    def construct(self, input_ids, enc_states, enc_attention_mask, seq_length):
         """
         Multi-layer transformer decoder step.
         input_ids: [batch_size * beam_width]
@@ -988,17 +977,23 @@ class TransformerDecoderStep(nn.Cell):
         enc_attention_mask = enc_attention_mask[::, 0:input_len:1, ::]
 
         # call TransformerDecoder
-        decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask)
+        decoder_output = self.tfm_decoder(input_embedding, input_mask, enc_states, enc_attention_mask, -1, seq_length)
 
         # take the last step
         decoder_output = decoder_output[::, input_len-1:input_len:1, ::]
 
         # projection and log_prob
-        log_probs = self.projection(decoder_output, embedding_tables)
+        log_probs = self.projection(decoder_output, embedding_tables, 1)
 
         return log_probs
 
 
+@constexpr
+def convert_np_to_tensor_encoder(seq_length):
+    ones = np.ones(shape=(seq_length, seq_length))
+    return Tensor(np.tril(ones), dtype=mstype.float32)
+
+
 class TransformerModel(nn.Cell):
     """
     Transformer with encoder and decoder.
@@ -1021,12 +1016,13 @@ class TransformerModel(nn.Cell):
 
         self.input_mask_from_dataset = config.input_mask_from_dataset
         self.batch_size = config.batch_size
         self.seq_length = config.seq_length
         self.hidden_size = config.hidden_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embedding_size = config.hidden_size
 
         self.last_idx = self.num_hidden_layers - 1
         self.beam_width = config.beam_width
         self.max_decode_length = config.max_decode_length
 
         self.tfm_embedding_lookup = EmbeddingLookup(
             vocab_size=config.vocab_size,
@@ -1048,7 +1044,6 @@ class TransformerModel(nn.Cell):
         self.tfm_encoder = TransformerEncoder(
             batch_size=self.batch_size,
             hidden_size=self.hidden_size,
-            seq_length=self.seq_length,
             num_attention_heads=config.num_attention_heads,
             num_hidden_layers=self.num_hidden_layers,
             intermediate_size=config.intermediate_size,
@@ -1062,15 +1057,12 @@ class TransformerModel(nn.Cell):
         if is_training:
             self.projection = PredLogProbs(
                 batch_size=self.batch_size,
-                seq_length=self.seq_length,
                 width=self.hidden_size,
                 compute_type=config.compute_type,
                 dtype=config.dtype)
             self.tfm_decoder = TransformerDecoder(
                 batch_size=self.batch_size,
                 hidden_size=self.hidden_size,
-                seq_length=self.seq_length,
-                enc_seq_length=self.seq_length,
                 num_attention_heads=config.num_attention_heads,
                 num_hidden_layers=self.num_hidden_layers,
                 intermediate_size=config.intermediate_size,
@@ -1083,14 +1075,12 @@ class TransformerModel(nn.Cell):
         else:
             self.projection = PredLogProbs(
                 batch_size=self.batch_size * config.beam_width,
-                seq_length=1,
                 width=self.hidden_size,
                 compute_type=config.compute_type,
                 dtype=config.dtype)
             self.tfm_decoder = TransformerDecoderStep(
                 batch_size=self.batch_size * config.beam_width,
                 hidden_size=self.hidden_size,
-                enc_seq_length=self.seq_length,
                 max_decode_length=config.max_decode_length,
                 num_hidden_layers=config.num_hidden_layers,
                 num_attention_heads=config.num_attention_heads,
@@ -1113,24 +1103,24 @@ class TransformerModel(nn.Cell):
                 length_penalty_weight=config.length_penalty_weight,
                 max_decode_length=config.max_decode_length)
 
             self.tfm_decoder.add_flags(loop_can_unroll=True)
-            self.tile_beam = TileBeam(beam_width=self.beam_width)
-            ones = np.ones(shape=(self.batch_size, self.max_decode_length))
-            self.encdec_mask = Tensor(ones, mstype.float32)
 
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = CastWrapper(dst_type=config.compute_type)
         self.expand = P.ExpandDims()
         self.multiply = P.Mul()
         self.shape = P.Shape()
 
         self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask()
 
         if is_training:
             ones = np.ones(shape=(self.seq_length, self.seq_length))
             self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
+        else:
+            self.tile_beam = TileBeam(beam_width=config.beam_width)
+            ones = np.ones(shape=(config.batch_size, config.max_decode_length))
+            self.encdec_mask = Tensor(ones, dtype=mstype.float32)
 
     def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
         """Transformer with encoder and decoder."""
+        seq_length = self.shape(source_ids)[1]
 
         # process source sentence
         src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
         src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
@@ -1138,21 +1128,24 @@ class TransformerModel(nn.Cell):
         enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
         # transformer encoder
         encoder_output = self.tfm_encoder(self.cast_compute_type(src_embedding_output),
-                                          self.cast_compute_type(enc_attention_mask))
+                                          self.cast_compute_type(enc_attention_mask),
+                                          seq_length)
 
         if self.is_training:
+            future_mask = convert_np_to_tensor_encoder(seq_length)
             # process target sentence
             tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
             tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
             # attention mask [batch_size, seq_length, seq_length]
             tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
-            tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0))
+            tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(future_mask, 0))
             # transformer decoder
             decoder_output = self.tfm_decoder(self.cast_compute_type(tgt_embedding_output),
                                               self.cast_compute_type(tgt_attention_mask),
-                                              encoder_output, enc_attention_mask)
+                                              encoder_output, enc_attention_mask,
+                                              seq_length, seq_length)
             # calculate logits and log_probs
-            log_probs = self.projection(decoder_output, embedding_tables)
+            log_probs = self.projection(decoder_output, embedding_tables, seq_length)
             ret = log_probs
         else:
             beam_encoder_output = self.tile_beam(encoder_output)