Browse Source

!213 use StridedSlice instead of Slice in absolute position embedding code in bert model

Merge pull request !213 from yoonlee666/master
tags/v0.2.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
865b8a79db
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      mindspore/model_zoo/Bert_NEZHA/bert_model.py

+ 2
- 2
mindspore/model_zoo/Bert_NEZHA/bert_model.py View File

@@ -194,7 +194,7 @@ class EmbeddingPostprocessor(nn.Cell):
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.GatherV2()
self.use_relative_positions = use_relative_positions
self.slice = P.Slice()
self.slice = P.StridedSlice()
self.full_position_embeddings = Parameter(initializer
(TruncatedNormal(initializer_range),
[max_position_embeddings,
@@ -216,7 +216,7 @@ class EmbeddingPostprocessor(nn.Cell):
output += token_type_embeddings
if not self.use_relative_positions:
_, seq, width = self.shape
position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width])
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
output += position_embeddings
output = self.layernorm(output)


Loading…
Cancel
Save