Browse Source

!6895 bugfix tinybert

Merge pull request !6895 from yoonlee666/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
05e55ef467
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      model_zoo/official/nlp/tinybert/src/tinybert_model.py

+ 4
- 2
model_zoo/official/nlp/tinybert/src/tinybert_model.py View File

@@ -844,9 +844,10 @@ class BertModel(nn.Cell):
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
batch_size = P.Shape()(input_ids)[0]
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(-1, 1, self.hidden_size),
(batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)
@@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell):
attention_mask)
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype)
# pooler
batch_size = P.Shape()(input_ids)[0]
sequence_slice = self.slice(sequence_output,
(0, 0, 0),
(-1, 1, self.hidden_size),
(batch_size, 1, self.hidden_size),
(1, 1, 1))
first_token = self.squeeze_1(sequence_slice)
pooled_output = self.dense(first_token)


Loading…
Cancel
Save