You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Layers.py 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. ''' Define the Layers '''
  2. import torch.nn as nn
  3. from transformer.SubLayers import MultiHeadAttention, PositionwiseFeedForward
  4. __author__ = "Yu-Hsiang Huang"
  5. class EncoderLayer(nn.Module):
  6. ''' Compose with two layers '''
  7. def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
  8. super(EncoderLayer, self).__init__()
  9. self.slf_attn = MultiHeadAttention(
  10. n_head, d_model, d_k, d_v, dropout=dropout)
  11. self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
  12. def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
  13. enc_output, enc_slf_attn = self.slf_attn(
  14. enc_input, enc_input, enc_input, mask=slf_attn_mask)
  15. enc_output *= non_pad_mask
  16. enc_output = self.pos_ffn(enc_output)
  17. enc_output *= non_pad_mask
  18. return enc_output, enc_slf_attn
  19. class DecoderLayer(nn.Module):
  20. ''' Compose with three layers '''
  21. def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
  22. super(DecoderLayer, self).__init__()
  23. self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
  24. self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
  25. self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
  26. def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
  27. dec_output, dec_slf_attn = self.slf_attn(
  28. dec_input, dec_input, dec_input, mask=slf_attn_mask)
  29. dec_output *= non_pad_mask
  30. dec_output, dec_enc_attn = self.enc_attn(
  31. dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
  32. dec_output *= non_pad_mask
  33. dec_output = self.pos_ffn(dec_output)
  34. dec_output *= non_pad_mask
  35. return dec_output, dec_slf_attn, dec_enc_attn