@@ -14,7 +14,7 @@ class DotAttention(nn.Module):
     """
     TODO
     """
-    def __init__(self, key_size, value_size, dropout=0.1):
+    def __init__(self, key_size, value_size, dropout=0):
         super(DotAttention, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
|
|
|
@@ -25,14 +25,14 @@ class DotAttention(nn.Module):
     def forward(self, Q, K, V, mask_out=None):
         """

-        :param Q: [batch, seq_len, key_size]
-        :param K: [batch, seq_len, key_size]
-        :param V: [batch, seq_len, value_size]
-        :param mask_out: [batch, seq_len]
+        :param Q: [batch, seq_len_q, key_size]
+        :param K: [batch, seq_len_k, key_size]
+        :param V: [batch, seq_len_k, value_size]
+        :param mask_out: [batch, 1, seq_len] or [batch, seq_len_q, seq_len_k]
         """
         output = torch.matmul(Q, K.transpose(1, 2)) / self.scale
         if mask_out is not None:
-            output.masked_fill_(mask_out, -float('inf'))
+            output.masked_fill_(mask_out, -1e8)
         output = self.softmax(output)
         output = self.drop(output)
         return torch.matmul(output, V)
|
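The switch from `-float('inf')` to `-1e8` is worth calling out. If the mask ever covers every key position in a row (for example an all-padding row), softmax over a row of `-inf` values produces NaN, which then propagates through the weighted sum; a large finite negative value instead degrades to a uniform row. A minimal sketch of the difference (my illustration, not part of the patch):

```python
import torch

scores = torch.zeros(1, 3)                  # one query position, three keys
mask = torch.ones(1, 3, dtype=torch.bool)   # every key masked out

with_inf = scores.masked_fill(mask, -float('inf')).softmax(dim=-1)
with_1e8 = scores.masked_fill(mask, -1e8).softmax(dim=-1)

print(with_inf)   # tensor([[nan, nan, nan]])          -> NaN reaches the output
print(with_1e8)   # tensor([[0.3333, 0.3333, 0.3333]]) -> finite, falls back to uniform
```

When at least one position in a row is left unmasked, `exp(-1e8 - max)` underflows to zero anyway, so masked positions still end up with effectively zero weight.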
|
|
@@ -58,7 +58,8 @@ class MultiHeadAttention(nn.Module):
         self.q_in = nn.Linear(input_size, in_size)
         self.k_in = nn.Linear(input_size, in_size)
         self.v_in = nn.Linear(input_size, in_size)
-        self.attention = DotAttention(key_size=key_size, value_size=value_size)
+        # follow the paper, do not apply dropout within dot-product
+        self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=0)
         self.out = nn.Linear(value_size * num_head, input_size)
         self.drop = TimestepDropout(dropout)
         self.reset_parameters()
|
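With the default changed to `dropout=0` and the explicit `dropout=0` here, the `TimestepDropout(dropout)` on the projected output becomes the only dropout applied by `MultiHeadAttention`; the attention weights themselves are no longer dropped, matching the comment above. A small sanity check of the new behaviour (shapes made up for illustration; it assumes `DotAttention` from this file is in scope):

```python
import torch

attn = DotAttention(key_size=64, value_size=64)   # dropout now defaults to 0
attn.train()                                       # even in training mode...

q = torch.rand(2, 5, 64)    # [batch, seq_len_q, key_size]
k = torch.rand(2, 7, 64)    # [batch, seq_len_k, key_size]
v = torch.rand(2, 7, 64)    # [batch, seq_len_k, value_size]

out1, out2 = attn(q, k, v), attn(q, k, v)
print(out1.shape)                   # torch.Size([2, 5, 64]) -- one row per query position
print(torch.allclose(out1, out2))   # True: ...no dropout inside the dot-product
```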
|
|
@@ -73,28 +74,29 @@ class MultiHeadAttention(nn.Module):
     def forward(self, Q, K, V, atte_mask_out=None):
         """

-        :param Q: [batch, seq_len, model_size]
-        :param K: [batch, seq_len, model_size]
-        :param V: [batch, seq_len, model_size]
+        :param Q: [batch, seq_len_q, model_size]
+        :param K: [batch, seq_len_k, model_size]
+        :param V: [batch, seq_len_k, model_size]
         :param seq_mask: [batch, seq_len]
         """
-        batch, seq_len, _ = Q.size()
+        batch, sq, _ = Q.size()
+        sk = K.size(1)
         d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
         # input linear
-        q = self.q_in(Q).view(batch, seq_len, n_head, d_k)
-        k = self.k_in(K).view(batch, seq_len, n_head, d_k)
-        v = self.v_in(V).view(batch, seq_len, n_head, d_k)
+        q = self.q_in(Q).view(batch, sq, n_head, d_k)
+        k = self.k_in(K).view(batch, sk, n_head, d_k)
+        v = self.v_in(V).view(batch, sk, n_head, d_v)

         # transpose q, k and v to do batch attention
-        q = q.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
-        k = k.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_k)
-        v = v.permute(2, 0, 1, 3).contiguous().view(-1, seq_len, d_v)
+        q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
+        k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
+        v = v.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_v)
         if atte_mask_out is not None:
             atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
-        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, seq_len, d_v)
+        atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)

         # concat all heads, do output linear
-        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, seq_len, -1)
+        atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
         output = self.drop(self.out(atte))
         return output
|
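Taken together, these changes let `MultiHeadAttention` attend from a query sequence of one length over a key/value sequence of another length (cross-attention), with `atte_mask_out` broadcast across query positions and tiled over heads by the `repeat(n_head, 1, 1)` above. A hypothetical call, purely for illustration (the constructor arguments below are an assumption, since the `__init__` signature is not part of this diff):

```python
import torch

# Hypothetical construction; argument names are assumed, not taken from this diff.
mha = MultiHeadAttention(input_size=512, key_size=64, value_size=64, num_head=8)

Q = torch.rand(2, 5, 512)    # [batch, seq_len_q, model_size]
KV = torch.rand(2, 9, 512)   # [batch, seq_len_k, model_size]

# True marks positions to ignore; [batch, 1, seq_len_k] broadcasts over all queries.
mask = torch.zeros(2, 1, 9, dtype=torch.bool)
mask[:, :, 7:] = True        # pretend the last two key positions are padding

out = mha(Q, KV, KV, atte_mask_out=mask)
print(out.shape)             # torch.Size([2, 5, 512]) -- one output per query position
```

Before this patch, the `view(batch, seq_len, ...)` calls reused the query length for K and V, so a call with `seq_len_q != seq_len_k` like this one would raise a shape error.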