| @@ -844,9 +844,10 @@ class BertModel(nn.Cell): | |||||
| attention_mask) | attention_mask) | ||||
| sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) | sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) | ||||
| # pooler | # pooler | ||||
| batch_size = P.Shape()(input_ids)[0] | |||||
| sequence_slice = self.slice(sequence_output, | sequence_slice = self.slice(sequence_output, | ||||
| (0, 0, 0), | (0, 0, 0), | ||||
| (-1, 1, self.hidden_size), | |||||
| (batch_size, 1, self.hidden_size), | |||||
| (1, 1, 1)) | (1, 1, 1)) | ||||
| first_token = self.squeeze_1(sequence_slice) | first_token = self.squeeze_1(sequence_slice) | ||||
| pooled_output = self.dense(first_token) | pooled_output = self.dense(first_token) | ||||
| @@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell): | |||||
| attention_mask) | attention_mask) | ||||
| sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) | sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) | ||||
| # pooler | # pooler | ||||
| batch_size = P.Shape()(input_ids)[0] | |||||
| sequence_slice = self.slice(sequence_output, | sequence_slice = self.slice(sequence_output, | ||||
| (0, 0, 0), | (0, 0, 0), | ||||
| (-1, 1, self.hidden_size), | |||||
| (batch_size, 1, self.hidden_size), | |||||
| (1, 1, 1)) | (1, 1, 1)) | ||||
| first_token = self.squeeze_1(sequence_slice) | first_token = self.squeeze_1(sequence_slice) | ||||
| pooled_output = self.dense(first_token) | pooled_output = self.dense(first_token) | ||||