|
|
|
@@ -51,8 +51,6 @@ class MultiheadAttention(nn.Module): |
|
|
|
if mask is not None: |
|
|
|
''' |
|
|
|
mask: [batch size, num_heads, seq_len, seq_len] |
|
|
|
mask后两维(seq_len, seq_len)矩阵来看,其中有的行可能都是true(1),对应句子中<pad>位看的行 |
|
|
|
导致softmax后该行的每个位置的attn prob都为1/n而非0,所以此处需重置为0 |
|
|
|
|
|
|
|
>>> F.softmax([-1e10, -100, -100]) |
|
|
|
>>> [0.00, 0.50, 0.50] |
|
|
|
|