''' Define the Transformer model '''
import torch
import torch.nn as nn
import numpy as np
import transformer.Constants as Constants
from transformer.Layers import EncoderLayer, DecoderLayer

__author__ = "Yu-Hsiang Huang"


def get_non_pad_mask(seq):
    assert seq.dim() == 2
    return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1)


def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    ''' Sinusoid position encoding table '''

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)


def get_attn_key_pad_mask(seq_k, seq_q):
    ''' For masking out the padding part of the key sequence. '''

    # Expand to fit the shape of the key-query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(Constants.PAD)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk

    return padding_mask


def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''

    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

    return subsequent_mask


class Encoder(nn.Module):
    ''' An encoder model with self-attention mechanism. '''

    def __init__(
            self,
            n_src_vocab, len_max_seq, d_word_vec,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()

        n_position = len_max_seq + 1

        self.src_word_emb = nn.Embedding(
            n_src_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, return_attns=False):

        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,


class Decoder(nn.Module):
    ''' A decoder model with self-attention mechanism. '''

    def __init__(
            self,
            n_tgt_vocab, len_max_seq, d_word_vec,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()

        n_position = len_max_seq + 1

        self.tgt_word_emb = nn.Embedding(
            n_tgt_vocab, d_word_vec, padding_idx=Constants.PAD)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):

        dec_slf_attn_list, dec_enc_attn_list = [], []

        # -- Prepare masks
        non_pad_mask = get_non_pad_mask(tgt_seq)

        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

        # -- Forward
        dec_output = self.tgt_word_emb(tgt_seq) + self.position_enc(tgt_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        if return_attns:
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output,


class Transformer(nn.Module):
    ''' A sequence-to-sequence model with attention mechanism. '''

    def __init__(
            self,
            n_src_vocab, n_tgt_vocab, len_max_seq,
            d_word_vec=512, d_model=512, d_inner=2048,
            n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
            tgt_emb_prj_weight_sharing=True,
            emb_src_tgt_weight_sharing=True):

        super().__init__()

        self.encoder = Encoder(
            n_src_vocab=n_src_vocab, len_max_seq=len_max_seq,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        self.decoder = Decoder(
            n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
        nn.init.xavier_normal_(self.tgt_word_prj.weight)

        assert d_model == d_word_vec, \
            'To facilitate the residual connections, ' \
            'the dimensions of all module outputs shall be the same.'

        if tgt_emb_prj_weight_sharing:
            # Share the weight matrix between the target word embedding & the final logit dense layer
            self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
            self.x_logit_scale = (d_model ** -0.5)
        else:
            self.x_logit_scale = 1.

        if emb_src_tgt_weight_sharing:
            # Share the weight matrix between the source & target word embeddings
            assert n_src_vocab == n_tgt_vocab, \
                'To share word embedding table, the vocabulary size of src/tgt shall be the same.'
            self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight

    def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):

        tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]

        enc_output, *_ = self.encoder(src_seq, src_pos)
        dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
        seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale

        return seq_logit.view(-1, seq_logit.size(2))
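

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test that
# builds a tiny Transformer and runs one forward pass on random token ids.
# It assumes positions are 1-based with 0 reserved for padding, matching
# `padding_idx=0` in the sinusoid table above, and that the source and target
# vocabularies are shared (the default weight-sharing flags require equal
# vocabulary sizes). All hyperparameters below are illustrative, not the
# values used for training, and behavior may vary across PyTorch versions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, src_len, tgt_len, vocab_size, max_len = 2, 7, 6, 100, 32

    def make_batch(seq_len):
        # Random non-pad token ids and their 1-based position indices.
        seq = torch.randint(Constants.PAD + 1, vocab_size, (batch_size, seq_len))
        pos = torch.arange(1, seq_len + 1).unsqueeze(0).expand(batch_size, -1)
        return seq, pos

    src_seq, src_pos = make_batch(src_len)
    tgt_seq, tgt_pos = make_batch(tgt_len)

    model = Transformer(
        n_src_vocab=vocab_size, n_tgt_vocab=vocab_size, len_max_seq=max_len,
        d_word_vec=32, d_model=32, d_inner=64,
        n_layers=2, n_head=4, d_k=8, d_v=8, dropout=0.1)

    logits = model(src_seq, src_pos, tgt_seq, tgt_pos)
    # forward() drops the last target position (teacher forcing shift), so the
    # flattened logit matrix has batch_size * (tgt_len - 1) rows.
    assert logits.shape == (batch_size * (tgt_len - 1), vocab_size)
    print('logit matrix:', logits.shape)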