You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Modules.py 729 B

12345678910111213141516171819202122232425262728
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. __author__ = "Yu-Hsiang Huang"
  5. class ScaledDotProductAttention(nn.Module):
  6. ''' Scaled Dot-Product Attention '''
  7. def __init__(self, temperature, attn_dropout=0.1):
  8. super().__init__()
  9. self.temperature = temperature
  10. self.dropout = nn.Dropout(attn_dropout)
  11. self.softmax = nn.Softmax(dim=2)
  12. def forward(self, q, k, v, mask=None):
  13. attn = torch.bmm(q, k.transpose(1, 2))
  14. attn = attn / self.temperature
  15. if mask is not None:
  16. attn = attn.masked_fill(mask, -np.inf)
  17. attn = self.softmax(attn)
  18. attn = self.dropout(attn)
  19. output = torch.bmm(attn, v)
  20. return output, attn