|
|
|
@@ -542,7 +542,7 @@ class BertAttention(nn.Cell): |
|
|
|
attention_probs_r = self.reshape( |
|
|
|
attention_probs_t, |
|
|
|
(self.from_seq_length, |
|
|
|
self.batch_num, |
|
|
|
-1, |
|
|
|
self.to_seq_length)) |
|
|
|
# value_position_scores is [F, B * N, H] |
|
|
|
value_position_scores = self.matmul(attention_probs_r, |
|
|
|
@@ -550,7 +550,7 @@ class BertAttention(nn.Cell): |
|
|
|
# value_position_scores_r is [F, B, N, H] |
|
|
|
value_position_scores_r = self.reshape(value_position_scores, |
|
|
|
(self.from_seq_length, |
|
|
|
self.batch_size, |
|
|
|
-1, |
|
|
|
self.num_attention_heads, |
|
|
|
self.size_per_head)) |
|
|
|
# value_position_scores_r_t is [B, N, F, H] |
|
|
|
|