From a2af684dd456ba591a805d5fc761572fb3157958 Mon Sep 17 00:00:00 2001 From: yoonlee666 Date: Fri, 25 Sep 2020 18:02:26 +0800 Subject: [PATCH] bugfix tinybert --- model_zoo/official/nlp/tinybert/src/tinybert_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_model.py b/model_zoo/official/nlp/tinybert/src/tinybert_model.py index 5e8dc8436b..09504abcd8 100644 --- a/model_zoo/official/nlp/tinybert/src/tinybert_model.py +++ b/model_zoo/official/nlp/tinybert/src/tinybert_model.py @@ -844,9 +844,10 @@ class BertModel(nn.Cell): attention_mask) sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) # pooler + batch_size = P.Shape()(input_ids)[0] sequence_slice = self.slice(sequence_output, (0, 0, 0), - (-1, 1, self.hidden_size), + (batch_size, 1, self.hidden_size), (1, 1, 1)) first_token = self.squeeze_1(sequence_slice) pooled_output = self.dense(first_token) @@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell): attention_mask) sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) # pooler + batch_size = P.Shape()(input_ids)[0] sequence_slice = self.slice(sequence_output, (0, 0, 0), - (-1, 1, self.hidden_size), + (batch_size, 1, self.hidden_size), (1, 1, 1)) first_token = self.squeeze_1(sequence_slice) pooled_output = self.dense(first_token)