@@ -20,11 +20,11 @@ import numpy as np
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 import mindspore.ops.functional as F
 from mindspore.common.initializer import TruncatedNormal, initializer
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from .beam_search import BeamSearchDecoder, TileBeam
 from .weight_init import normal_weight, weight_variable


 class TransformerConfig:
     """
@@ -118,9 +118,7 @@ class EmbeddingLookup(nn.Cell):
         self.vocab_size = vocab_size
         self.embedding_size = embedding_size
         self.use_one_hot_embeddings = use_one_hot_embeddings
-        self.embedding_table = Parameter(initializer
-                                         (TruncatedNormal(initializer_range),
-                                          [vocab_size, embedding_size]),
+        self.embedding_table = Parameter(normal_weight([vocab_size, embedding_size], embedding_size),
                                          name='embedding_table')
         self.expand = P.ExpandDims()
         self.shape_flat = (-1,)
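Note: the two initializer helpers imported from .weight_init are not shown in this diff. The sketch below is only inferred from the call sites, normal_weight(shape, num_units) and weight_variable(shape); it is a hypothetical illustration, and the real definitions in weight_init.py may differ.

    # hypothetical sketch, not the repository's weight_init.py
    import numpy as np
    from mindspore.common.tensor import Tensor

    def normal_weight(shape, num_units):
        """Normal(0, num_units**-0.5) init, a common choice for transformer embedding tables."""
        return Tensor(np.random.normal(0.0, num_units ** -0.5, shape).astype(np.float32))

    def weight_variable(shape):
        """Glorot-style uniform init; shape is the Dense weight shape [out_channels, in_channels]."""
        fan_out, fan_in = shape[0], shape[1]
        limit = np.sqrt(6.0 / (fan_in + fan_out))
        return Tensor(np.random.uniform(-limit, limit, shape).astype(np.float32))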
@@ -138,8 +136,7 @@ class EmbeddingLookup(nn.Cell):
         flat_ids = self.reshape(input_ids, self.shape_flat)
         if self.use_one_hot_embeddings:
             one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
-            output_for_reshape = self.array_mul(
-                one_hot_ids, self.embedding_table)
+            output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
         else:
             output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
@@ -329,22 +326,22 @@ class MultiheadAttention(nn.Cell):
                                     units,
                                     activation=query_act,
                                     has_bias=False,
-                                    weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                                    weight_init=weight_variable([units, from_tensor_width])).to_float(compute_type)
         self.key_layer = nn.Dense(to_tensor_width,
                                   units,
                                   activation=key_act,
                                   has_bias=False,
-                                  weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                                  weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
         self.value_layer = nn.Dense(to_tensor_width,
                                     units,
                                     activation=value_act,
                                     has_bias=False,
-                                    weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                                    weight_init=weight_variable([units, to_tensor_width])).to_float(compute_type)
         self.out_layer = nn.Dense(units,
                                   out_tensor_width,
                                   activation=out_act,
                                   has_bias=False,
-                                  weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                                  weight_init=weight_variable([out_tensor_width, units])).to_float(compute_type)

         self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
         self.shape_to = (batch_size, to_seq_length, num_attention_heads, size_per_head)
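A quick sanity check of the weight-shape convention these weight_variable calls assume: nn.Dense(in_channels, out_channels) stores its weight as [out_channels, in_channels], which is why the query/key/value layers pass [units, *_tensor_width] and the output layer passes [out_tensor_width, units]. Illustrative snippet, not part of the patch:

    import numpy as np
    import mindspore.nn as nn
    from mindspore.common.tensor import Tensor

    # nn.Dense(8, 4) expects a weight of shape (4, 8) = (out_channels, in_channels)
    dense = nn.Dense(8, 4, weight_init=Tensor(np.zeros((4, 8), np.float32)))
    print(dense.weight.shape)  # (4, 8)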
@@ -518,10 +515,10 @@ class FeedForward(nn.Cell):
         self.conv1 = nn.Dense(in_channels,
                               hidden_size,
                               activation=hidden_act,
-                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                              weight_init=weight_variable([hidden_size, in_channels])).to_float(compute_type)
         self.conv2 = nn.Dense(hidden_size,
                               out_channels,
-                              weight_init=TruncatedNormal(initializer_range)).to_float(compute_type)
+                              weight_init=weight_variable([out_channels, hidden_size])).to_float(compute_type)

         self.preprocess = LayerPreprocess(in_channels=in_channels)
         self.postprocess = LayerPostprocess(dropout_prob=hidden_dropout_prob)
@@ -1108,7 +1105,13 @@ class TransformerModel(nn.Cell):
             embedding_size=self.embedding_size,
             use_one_hot_embeddings=use_one_hot_embeddings,
             initializer_range=config.initializer_range)
-        self.tfm_embedding_postprocessor = EmbeddingPostprocessor(
+        self.tfm_embedding_postprocessor_for_encoder = EmbeddingPostprocessor(
             embedding_size=self.embedding_size,
             use_one_hot_embeddings=use_one_hot_embeddings,
             initializer_range=0.02,
+            max_position_embeddings=config.max_position_embeddings,
+            dropout_prob=config.hidden_dropout_prob)
+        self.tfm_embedding_postprocessor_for_decoder = EmbeddingPostprocessor(
+            embedding_size=self.embedding_size,
+            use_one_hot_embeddings=use_one_hot_embeddings,
+            initializer_range=0.02,
@@ -1171,7 +1174,7 @@ class TransformerModel(nn.Cell):
             hidden_act=config.hidden_act,
             compute_type=config.compute_type,
             embedding_lookup=self.tfm_embedding_lookup,
-            embedding_processor=self.tfm_embedding_postprocessor,
+            embedding_processor=self.tfm_embedding_postprocessor_for_decoder,
             projection=self.projection)
         self.tfm_decoder = BeamSearchDecoder(
             batch_size=config.batch_size,
@@ -1195,15 +1198,14 @@ class TransformerModel(nn.Cell):
             ones = np.ones(shape=(self.seq_length, self.seq_length))
             self.future_mask = Tensor(np.tril(ones), dtype=mstype.float32)
         else:
-            self.tile_beam = TileBeam(
-                beam_width=config.beam_width)
+            self.tile_beam = TileBeam(beam_width=config.beam_width)
             ones = np.ones(shape=(config.batch_size, config.max_decode_length))
             self.encdec_mask = Tensor(ones, dtype=mstype.float32)

     def construct(self, source_ids, source_mask, target_ids=None, target_mask=None):
         # process source sentence
         src_word_embeddings, embedding_tables = self.tfm_embedding_lookup(source_ids)
-        src_embedding_output = self.tfm_embedding_postprocessor(src_word_embeddings)
+        src_embedding_output = self.tfm_embedding_postprocessor_for_encoder(src_word_embeddings)
         # attention mask [batch_size, seq_length, seq_length]
         enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
         # transformer encoder
@@ -1213,7 +1215,7 @@ class TransformerModel(nn.Cell):
         if self.is_training:
             # process target sentence
             tgt_word_embeddings, _ = self.tfm_embedding_lookup(target_ids)
-            tgt_embedding_output = self.tfm_embedding_postprocessor(tgt_word_embeddings)
+            tgt_embedding_output = self.tfm_embedding_postprocessor_for_decoder(tgt_word_embeddings)
             # attention mask [batch_size, seq_length, seq_length]
             tgt_attention_mask = self._create_attention_mask_from_input_mask(target_mask)
             tgt_attention_mask = self.multiply(tgt_attention_mask, self.expand(self.future_mask, 0))
@@ -1223,15 +1225,14 @@ class TransformerModel(nn.Cell):
                                               encoder_output, enc_attention_mask)
             # calculate logits and log_probs
             log_probs = self.projection(decoder_output, embedding_tables)
-            return log_probs
-
-        beam_encoder_output = self.tile_beam(encoder_output)
+            ret = log_probs
+        else:
+            beam_encoder_output = self.tile_beam(encoder_output)

-        enc_attention_mask = self.multiply(
-            enc_attention_mask[::, 0:1:1, ::],
-            self.expand(self.encdec_mask, -1))
+            enc_attention_mask = self.multiply(enc_attention_mask[::, 0:1:1, ::], self.expand(self.encdec_mask, -1))

-        beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
-        beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
-        predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
-        return predicted_ids
+            beam_enc_attention_mask = self.tile_beam(enc_attention_mask)
+            beam_enc_attention_mask = self.cast_compute_type(beam_enc_attention_mask)
+            predicted_ids = self.tfm_decoder(beam_encoder_output, beam_enc_attention_mask)
+            ret = predicted_ids
+        return ret
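The construct refactor above replaces two branch-local returns with a single ret returned at the end, presumably to keep construct single-exit, which MindSpore graph compilation of control flow handles more easily. A self-contained toy cell illustrating the same pattern (illustrative only, not part of the patch):

    import mindspore.nn as nn

    class BranchyCell(nn.Cell):
        """Toy cell with the same single-return control flow as the refactored construct."""
        def __init__(self, is_training):
            super(BranchyCell, self).__init__()
            self.is_training = is_training

        def construct(self, x):
            if self.is_training:
                ret = x * 2    # stands in for the training (log_probs) branch
            else:
                ret = x + 1    # stands in for the inference (beam search) branch
            return ret         # single exit point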