@@ -131,6 +131,7 @@ class EmbeddingLookup(nn.Cell):
         self.shape = P.Shape()

     def construct(self, input_ids):
+        """Get an embedding lookup table with a fixed dictionary and size."""
         input_shape = self.shape(input_ids)

         flat_ids = self.reshape(input_ids, self.shape_flat)
@@ -200,6 +201,7 @@ class EmbeddingPostprocessor(nn.Cell):
         self.shape = P.Shape()

     def construct(self, word_embeddings):
+        """Postprocessors apply positional embeddings to word embeddings."""
         input_shape = self.shape(word_embeddings)
         input_len = input_shape[1]

@@ -377,7 +379,7 @@ class MultiheadAttention(nn.Cell):
         self.softmax_cast = P.Cast()

     def construct(self, from_tensor, to_tensor, attention_mask=None):
-        # reshape 2d/3d input tensors to 2d
+        """reshape 2d/3d input tensors to 2d"""
         from_tensor_2d = self.reshape(from_tensor, self.shape_from_2d)
         to_tensor_2d = self.reshape(to_tensor, self.shape_to_2d)
         query_out = self.query_layer(from_tensor_2d)
@@ -476,6 +478,7 @@ class SelfAttention(nn.Cell):
         self.reshape = P.Reshape()
         self.shape = (-1, hidden_size)
     def construct(self, input_tensor, memory_tensor, attention_mask):
+        """Apply self-attention."""
         input_tensor = self.reshape(input_tensor, self.shape)
         memory_tensor = self.reshape(memory_tensor, self.shape)

@@ -831,6 +834,7 @@ class CreateAttentionMaskFromInputMask(nn.Cell):
         self.batch_matmul = P.BatchMatMul()

     def construct(self, input_mask):
+        """Create attention mask according to input mask."""
         input_shape = self.shape(input_mask)
         shape_right = (input_shape[0], 1, input_shape[1])
         shape_left = input_shape + (1,)
@@ -876,6 +880,7 @@ class PredLogProbs(nn.Cell):
     def construct(self,
                   input_tensor,
                   output_weights):
+        """Get log probs."""
         input_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
         input_tensor = self.cast(input_tensor, self.compute_type)
         output_weights = self.cast(output_weights, self.compute_type)
@@ -962,7 +967,10 @@ class TransformerDecoderStep(nn.Cell):
         self.cast_compute_type = CastWrapper(dst_type=compute_type)

     def construct(self, input_ids, enc_states, enc_attention_mask):
-        # input_ids: [batch_size * beam_width]
+        """
+        Multi-layer transformer decoder step.
+        input_ids: [batch_size * beam_width]
+        """
         # process embedding
         input_embedding, embedding_tables = self.tfm_embedding_lookup(input_ids)
         input_embedding = self.tfm_embedding_processor(input_embedding)
@@ -1122,6 +1130,7 @@ class TransformerModel(nn.Cell):
         self.encdec_mask = Tensor(ones, dtype=mstype.float32)

     def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
+        """Transformer with encoder and decoder."""
         # process source sentence
         src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
         src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
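
For reference, the CreateAttentionMaskFromInputMask hunk above builds the 3D attention mask as a batched outer product of the 2D input mask with itself: shape_left reshapes the mask to [batch, seq, 1], shape_right to [batch, 1, seq], and the batch matmul combines them. A minimal NumPy sketch of that idea, assuming a 1/0 padding-mask convention and plain NumPy in place of the MindSpore ops used in the diff:

import numpy as np

def make_attention_mask(input_mask):
    # input_mask: (batch_size, seq_length); 1 for real tokens, 0 for padding (assumed convention).
    batch_size, seq_length = input_mask.shape
    left = input_mask.reshape(batch_size, seq_length, 1).astype(np.float32)   # shape_left in the hunk
    right = input_mask.reshape(batch_size, 1, seq_length).astype(np.float32)  # shape_right in the hunk
    # Batched outer product, analogous to applying P.BatchMatMul() to the two reshaped masks.
    return np.matmul(left, right)

# Example: one sequence with two real tokens followed by two padding positions.
attention_mask = make_attention_mask(np.array([[1, 1, 0, 0]]))
# attention_mask[0] is 1 only where both the query and the key position are real tokens.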