|
|
|
@@ -844,9 +844,10 @@ class BertModel(nn.Cell): |
|
|
|
attention_mask) |
|
|
|
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) |
|
|
|
# pooler |
|
|
|
batch_size = P.Shape()(input_ids)[0] |
|
|
|
sequence_slice = self.slice(sequence_output, |
|
|
|
(0, 0, 0), |
|
|
|
(-1, 1, self.hidden_size), |
|
|
|
(batch_size, 1, self.hidden_size), |
|
|
|
(1, 1, 1)) |
|
|
|
first_token = self.squeeze_1(sequence_slice) |
|
|
|
pooled_output = self.dense(first_token) |
|
|
|
@@ -939,9 +940,10 @@ class TinyBertModel(nn.Cell): |
|
|
|
attention_mask) |
|
|
|
sequence_output = self.cast(encoder_output[self.last_idx], self.dtype) |
|
|
|
# pooler |
|
|
|
batch_size = P.Shape()(input_ids)[0] |
|
|
|
sequence_slice = self.slice(sequence_output, |
|
|
|
(0, 0, 0), |
|
|
|
(-1, 1, self.hidden_size), |
|
|
|
(batch_size, 1, self.hidden_size), |
|
|
|
(1, 1, 1)) |
|
|
|
first_token = self.squeeze_1(sequence_slice) |
|
|
|
pooled_output = self.dense(first_token) |
|
|
|
|