
Models.py 7.4 kB

''' Define the Transformer model '''
import torch
import torch.nn as nn
import numpy as np
import transformer.Constants as Constants
from transformer.Layers import EncoderLayer, DecoderLayer

__author__ = "Yu-Hsiang Huang"


def get_non_pad_mask(seq):
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)
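
# Descriptive note on the table above: column 2i holds sin(pos / 10000^(2i / d_hid))
# and column 2i + 1 holds cos(pos / 10000^(2i / d_hid)), i.e. the fixed positional
# encoding from "Attention Is All You Need"; the row at `padding_idx` is zeroed so
# padded positions contribute no positional signal.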

def get_attn_key_pad_mask(seq_k, seq_q):
    ''' For masking out the padding part of key sequence. '''

    # Expand to fit the shape of key query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(Constants.PAD)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk

    return padding_mask


def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''

    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

    return subsequent_mask
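
# Illustrative shape note (a sketch, assuming Constants.PAD == 0): for a toy batch
#   seq = torch.tensor([[3, 5, 7, 0, 0]])
# get_non_pad_mask(seq) is a (1, 5, 1) float mask with zeros at the two PAD slots,
# get_attn_key_pad_mask(seq, seq) is (1, 5, 5) and is True wherever the *key*
# position is PAD, and get_subsequent_mask(seq) is the strictly upper-triangular
# (1, 5, 5) mask that hides future positions in decoder self-attention.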

class Encoder(nn.Module):
    ''' An encoder model with self attention mechanism. '''

    def __init__(
            self,
            n_src_vocab, len_max_seq, d_word_vec,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()

        n_position = len_max_seq + 1

        self.src_word_emb = nn.Embedding(
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, return_attns=False):

        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,

class Decoder(nn.Module):
    ''' A decoder model with self attention mechanism. '''

    def __init__(
            self,
            n_tgt_vocab, len_max_seq, d_word_vec,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()

        n_position = len_max_seq + 1

        self.tgt_word_emb = nn.Embedding(
            n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):

        dec_slf_attn_list, dec_enc_attn_list = [], []

        # -- Prepare masks
        non_pad_mask = get_non_pad_mask(tgt_seq)

        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

        # -- Forward
        dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        if return_attns:
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output,

class Transformer(nn.Module):
    ''' A sequence to sequence model with attention mechanism. '''

    def __init__(
            self,
            n_src_vocab, n_tgt_vocab, len_max_seq,
            d_word_vec=512, d_model=512, d_inner=2048,
            n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
            tgt_emb_prj_weight_sharing=True,
            emb_src_tgt_weight_sharing=True):

        super().__init__()

        self.encoder = Encoder(
            n_src_vocab=n_src_vocab, len_max_seq=len_max_seq,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        self.decoder = Decoder(
            n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
        nn.init.xavier_normal_(self.tgt_word_prj.weight)

        assert d_model == d_word_vec, \
            'To facilitate the residual connections, ' \
            'the dimensions of all module outputs shall be the same.'

        if tgt_emb_prj_weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
            self.x_logit_scale = (d_model ** -0.5)
        else:
            self.x_logit_scale = 1.

        if emb_src_tgt_weight_sharing:
            # Share the weight matrix between source & target word embeddings
            assert n_src_vocab == n_tgt_vocab, \
                "To share word embedding table, the vocabulary size of src/tgt shall be the same."
            self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight

    def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):

        tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]

        enc_output, *_ = self.encoder(src_seq, src_pos)
        dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
        seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale

        return seq_logit.view(-1, seq_logit.size(2))
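

if __name__ == '__main__':
    # Minimal smoke test: a sketch of how the pieces fit together, assuming
    # Constants.PAD == 0 and position indices starting at 1 (with 0 reserved for
    # padding, matching padding_idx=0 above). The hyper-parameters and shapes
    # here are illustrative, not taken from the original file.
    sz_b, len_s, vocab = 2, 10, 100

    model = Transformer(
        n_src_vocab=vocab, n_tgt_vocab=vocab, len_max_seq=len_s,
        d_word_vec=32, d_model=32, d_inner=64,
        n_layers=2, n_head=4, d_k=8, d_v=8)

    # Token ids drawn from [1, vocab) so no PAD appears; positions run 1..len_s.
    src_seq = torch.randint(1, vocab, (sz_b, len_s))
    tgt_seq = torch.randint(1, vocab, (sz_b, len_s))
    src_pos = torch.arange(1, len_s + 1).unsqueeze(0).expand(sz_b, -1)
    tgt_pos = torch.arange(1, len_s + 1).unsqueeze(0).expand(sz_b, -1)

    logits = model(src_seq, src_pos, tgt_seq, tgt_pos)
    # The decoder input is shifted right by one, so len_s - 1 steps are predicted.
    print(logits.shape)  # expected: torch.Size([sz_b * (len_s - 1), vocab])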