@@ -165,6 +165,7 @@ class EmbeddingPostprocessor(nn.Cell):
    def __init__(self,
                 embedding_size,
                 embedding_shape,
                 use_relative_positions=False,
                 use_token_type=False,
                 token_type_vocab_size=16,
                 use_one_hot_embeddings=False,
@@ -192,6 +193,13 @@ class EmbeddingPostprocessor(nn.Cell):
        self.layernorm = nn.LayerNorm(embedding_size)
        self.dropout = nn.Dropout(1 - dropout_prob)
        self.gather = P.GatherV2()
        self.use_relative_positions = use_relative_positions
        self.slice = P.Slice()
        self.full_position_embeddings = Parameter(initializer
                                                  (TruncatedNormal(initializer_range),
                                                   [max_position_embeddings,
                                                    embedding_size]),
                                                  name='full_position_embeddings')

    def construct(self, token_type_ids, word_embeddings):
        output = word_embeddings
@@ -206,6 +214,11 @@ class EmbeddingPostprocessor(nn.Cell):
            token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
            token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
            output += token_type_embeddings
        if not self.use_relative_positions:
            _, seq, width = self.shape
            position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width])
            position_embeddings = self.reshape(position_embeddings, (1, seq, width))
            output += position_embeddings
        output = self.layernorm(output)
        output = self.dropout(output)
        return output
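
Note (not part of the patch): as a reading aid for the construct hunk above, here is a minimal NumPy sketch of what the postprocessor computes when use_relative_positions is False. All names, shapes, and values below are illustrative assumptions, not code from this repository.

    import numpy as np

    # Illustrative sizes only: batch=2, seq=8, width=16, max_position_embeddings=32.
    batch, seq, width = 2, 8, 16
    word_embeddings = np.random.randn(batch, seq, width).astype(np.float32)
    token_type_ids = np.zeros((batch, seq), dtype=np.int32)
    token_type_table = np.random.randn(2, width).astype(np.float32)           # [token_type_vocab_size, width]
    full_position_embeddings = np.random.randn(32, width).astype(np.float32)  # [max_position_embeddings, width]

    output = word_embeddings
    # Token-type term: gather one table row per id, reshape to (batch, seq, width), add.
    output = output + token_type_table[token_type_ids.reshape(-1)].reshape(batch, seq, width)
    # Absolute-position term (skipped when use_relative_positions is True):
    # slice the first `seq` rows of the position table and broadcast over the batch.
    output = output + full_position_embeddings[:seq, :].reshape(1, seq, width)
    # LayerNorm and dropout follow in the real cell.
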
@@ -853,6 +866,7 @@ class BertModel(nn.Cell):
        self.bert_embedding_postprocessor = EmbeddingPostprocessor(
            embedding_size=self.embedding_size,
            embedding_shape=output_embedding_shape,
            use_relative_positions=config.use_relative_positions,
            use_token_type=True,
            token_type_vocab_size=config.type_vocab_size,
            use_one_hot_embeddings=use_one_hot_embeddings,
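
Note (not part of the patch): an instantiation mirroring the truncated call above might look roughly like the following. It assumes the EmbeddingPostprocessor definition from this file is in scope and that parameters not visible in the hunks (max_position_embeddings, initializer_range, dropout_prob) keep their defaults; the concrete sizes are placeholders, not values from the patch.

    postprocessor = EmbeddingPostprocessor(
        embedding_size=768,
        embedding_shape=[32, 128, 768],   # [batch_size, seq_length, embedding_size], placeholder sizes
        use_relative_positions=True,      # new flag: the absolute position-embedding branch is skipped
        use_token_type=True,
        token_type_vocab_size=2,
        use_one_hot_embeddings=False)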