From b0fc7b7289f37529aa9ae92e64355e4802c694e8 Mon Sep 17 00:00:00 2001 From: yoonlee666 Date: Fri, 10 Apr 2020 10:51:54 +0800 Subject: [PATCH] change op Slice to StridedSlice in bert model --- mindspore/model_zoo/Bert_NEZHA/bert_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/model_zoo/Bert_NEZHA/bert_model.py b/mindspore/model_zoo/Bert_NEZHA/bert_model.py index d7f9355b3c..b9c6e8c4a1 100644 --- a/mindspore/model_zoo/Bert_NEZHA/bert_model.py +++ b/mindspore/model_zoo/Bert_NEZHA/bert_model.py @@ -194,7 +194,7 @@ class EmbeddingPostprocessor(nn.Cell): self.dropout = nn.Dropout(1 - dropout_prob) self.gather = P.GatherV2() self.use_relative_positions = use_relative_positions - self.slice = P.Slice() + self.slice = P.StridedSlice() self.full_position_embeddings = Parameter(initializer (TruncatedNormal(initializer_range), [max_position_embeddings, @@ -216,7 +216,7 @@ class EmbeddingPostprocessor(nn.Cell): output += token_type_embeddings if not self.use_relative_positions: _, seq, width = self.shape - position_embeddings = self.slice(self.full_position_embeddings, [0, 0], [seq, width]) + position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1)) position_embeddings = self.reshape(position_embeddings, (1, seq, width)) output += position_embeddings output = self.layernorm(output)