|
|
|
@@ -194,7 +194,7 @@ class EmbeddingPostprocessor(nn.Cell): |
|
|
|
self.dropout = nn.Dropout(1 - dropout_prob) |
|
|
|
self.gather = P.GatherV2() |
|
|
|
self.use_relative_positions = use_relative_positions |
|
|
|
self.slice = P.Slice() |
|
|
|
self.slice = P.StridedSlice() |
|
|
|
self.full_position_embeddings = Parameter(initializer |
|
|
|
(TruncatedNormal(initializer_range), |
|
|
|
[max_position_embeddings, |
|
|
|
@@ -216,7 +216,7 @@ class EmbeddingPostprocessor(nn.Cell): |
|
|
|
output += token_type_embeddings |
|
|
|
if not self.use_relative_positions: |
|
|
|
_, seq, width = self.shape |
|
|
|
position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width]) |
|
|
|
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1)) |
|
|
|
position_embeddings = self.reshape(position_embeddings, (1, seq, width)) |
|
|
|
output += position_embeddings |
|
|
|
output = self.layernorm(output) |
|
|
|
|