@@ -194,7 +194,7 @@ class EmbeddingPostprocessor(nn.Cell):
         self.dropout = nn.Dropout(1 - dropout_prob)
         self.gather = P.GatherV2()
         self.use_relative_positions = use_relative_positions
-        self.slice = P.Slice()
+        self.slice = P.StridedSlice()
         self.full_position_embeddings = Parameter(initializer
                                                   (TruncatedNormal(initializer_range),
                                                    [max_position_embeddings,
@@ -216,7 +216,7 @@ class EmbeddingPostprocessor(nn.Cell):
             output += token_type_embeddings
         if not self.use_relative_positions:
             _, seq, width = self.shape
-            position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width])
+            position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
             position_embeddings = self.reshape(position_embeddings, (1, seq, width))
             output += position_embeddings
         output = self.layernorm(output)
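
The call site changes because the two operators take different arguments: P.Slice expects a begin index and a size per dimension, while P.StridedSlice expects a begin index, an exclusive end index, and a stride per dimension, hence the extra (1, 1). A minimal sketch contrasting the two signatures; the tensor and shapes here are made up for illustration and assume the older mindspore.ops.operations import style used in this file:

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.arange(12).reshape(3, 4).astype(np.float32))

# Slice(input, begin, size): `size` is the number of elements kept per dimension.
a = P.Slice()(x, (0, 0), (2, 4))

# StridedSlice(input, begin, end, strides): `end` is exclusive, one stride per dimension.
b = P.StridedSlice()(x, (0, 0), (2, 4), (1, 1))

print(a.shape, b.shape)  # both (2, 4); with unit strides the two ops return the same block
```

With unit strides, cutting the first `seq` rows and `width` columns out of full_position_embeddings therefore produces the same result before and after the patch.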