import torch
import torch.nn as nn
import torch.nn.functional as F

from fastNLP.models.base_model import BaseModel
from fastNLP.embeddings import TokenEmbedding
from fastNLP.core.const import Const
from fastNLP.core.utils import seq_len_to_mask


class ESIMModel(BaseModel):
    """ESIM (Enhanced Sequential Inference Model) for sentence-pair classification."""

    def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
                 dropout_embed=0.1):
        super(ESIMModel, self).__init__()

        self.embedding = init_embedding
        self.dropout_embed = EmbedDropout(p=dropout_embed)
        if hidden_size is None:
            hidden_size = self.embedding.embed_size
        self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
        # self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True)

        self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate),
                                       nn.Linear(8 * hidden_size, hidden_size),
                                       nn.ReLU())
        nn.init.xavier_uniform_(self.interfere[1].weight.data)
        self.bi_attention = SoftmaxAttention()

        # The composition encoder reads the output of `interfere`, which projects to hidden_size.
        self.rnn_high = BiRNN(hidden_size, hidden_size, dropout_rate=dropout_rate)
        # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,)

        self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate),
                                        nn.Linear(8 * hidden_size, hidden_size),
                                        nn.Tanh(),
                                        nn.Dropout(p=dropout_rate),
                                        nn.Linear(hidden_size, num_labels))

        self.dropout_rnn = nn.Dropout(p=dropout_rate)

        nn.init.xavier_uniform_(self.classifier[1].weight.data)
        nn.init.xavier_uniform_(self.classifier[4].weight.data)

    def forward(self, words1, words2, seq_len1, seq_len2):
        """
        :param words1: [batch, seq_len]
        :param words2: [batch, seq_len]
        :param seq_len1: [batch]
        :param seq_len2: [batch]
        :return: dict with Const.OUTPUT -> [batch, num_labels] logits
        """
        mask1 = seq_len_to_mask(seq_len1, words1.size(1))
        mask2 = seq_len_to_mask(seq_len2, words2.size(1))
        a0 = self.embedding(words1)  # a0: [B, PL, emb_dim]
        b0 = self.embedding(words2)  # b0: [B, HL, emb_dim]
        a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0)

        # Input encoding
        a = self.rnn(a0, mask1.byte())  # a: [B, PL, 2 * H]
        b = self.rnn(b0, mask2.byte())  # b: [B, HL, 2 * H]
        # a = self.dropout_rnn(self.rnn(a0, seq_len1)[0])  # a: [B, PL, 2 * H]
        # b = self.dropout_rnn(self.rnn(b0, seq_len2)[0])

        # Local inference: soft-align the sentences, then enhance with difference/product features
        ai, bi = self.bi_attention(a, mask1, b, mask2)
        a_ = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 8 * H]
        b_ = torch.cat((b, bi, b - bi, b * bi), dim=2)  # mb: [B, HL, 8 * H]
        a_f = self.interfere(a_)
        b_f = self.interfere(b_)

        # Inference composition
        a_h = self.rnn_high(a_f, mask1.byte())  # a_h: [B, PL, 2 * H]
        b_h = self.rnn_high(b_f, mask2.byte())  # b_h: [B, HL, 2 * H]
        # a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0])
        # b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0])

        # Pooling and classification
        a_avg = self.mean_pooling(a_h, mask1, dim=1)
        a_max, _ = self.max_pooling(a_h, mask1, dim=1)
        b_avg = self.mean_pooling(b_h, mask2, dim=1)
        b_max, _ = self.max_pooling(b_h, mask2, dim=1)
        out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1)  # v: [B, 8 * H]
        logits = torch.tanh(self.classifier(out))
        # logits = self.classifier(out)

        return {Const.OUTPUT: logits}

    def predict(self, words1, words2, seq_len1, seq_len2):
        pred = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT].argmax(-1)
        return {Const.OUTPUT: pred}

    # input: [batch_size, len, hidden]
    # mask:  [batch_size, len] (111...000)
    @staticmethod
    def mean_pooling(input, mask, dim=1):
        masks = mask.view(mask.size(0), mask.size(1), -1).float()
        return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=dim)

    @staticmethod
    def max_pooling(input, mask, dim=1):
        my_inf = 10e12
        masks = mask.view(mask.size(0), mask.size(1), -1)
        masks = masks.expand(-1, -1, input.size(2)).float()
        # Push padded positions towards -inf so they never win the max.
        return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim)
class EmbedDropout(nn.Dropout):
    """Dropout that zeroes whole embedding dimensions, with one mask shared across all timesteps."""

    def forward(self, sequences_batch):
        ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1])
        dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False)
        return dropout_mask.unsqueeze(1) * sequences_batch


class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_rate=0.3):
        super(BiRNN, self).__init__()
        self.dropout_rate = dropout_rate
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, bidirectional=True, batch_first=True)

    def forward(self, x, x_mask):
        # Sort the batch by descending length, as required by pack_padded_sequence
        lengths = x_mask.data.eq(True).long().sum(1)
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lengths = lengths[idx_sort].cpu()
        x = x.index_select(0, idx_sort)

        # Pack it up
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)

        # Apply dropout to the packed input
        if self.dropout_rate > 0:
            dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training)
            rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
        output = self.rnn(rnn_input)[0]

        # Unpack everything and restore the original batch order
        output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
        output = output.index_select(0, idx_unsort)

        # Pad back up to the full mask length (pad_packed_sequence only pads to the longest kept sequence)
        if output.size(1) != x_mask.size(1):
            padding = torch.zeros(output.size(0),
                                  x_mask.size(1) - output.size(1),
                                  output.size(2)).type(output.data.type())
            output = torch.cat([output, padding], 1)
        return output


def masked_softmax(tensor, mask):
    tensor_shape = tensor.size()
    reshaped_tensor = tensor.view(-1, tensor_shape[-1])

    # Reshape the mask so it matches the size of the input tensor.
    while mask.dim() < tensor.dim():
        mask = mask.unsqueeze(1)
    mask = mask.expand_as(tensor).contiguous().float()
    reshaped_mask = mask.view(-1, mask.size()[-1])

    result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
    result = result * reshaped_mask
    # 1e-13 is added to avoid divisions by zero.
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result.view(*tensor_shape)


def weighted_sum(tensor, weights, mask):
    w_sum = weights.bmm(tensor)

    # Zero out rows that correspond to padded positions of the attending sequence.
    while mask.dim() < w_sum.dim():
        mask = mask.unsqueeze(1)
    mask = mask.transpose(-1, -2)
    mask = mask.expand_as(w_sum).contiguous().float()
    return w_sum * mask


class SoftmaxAttention(nn.Module):
    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
        # Dot-product similarity between every premise and hypothesis position: [B, PL, HL]
        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1).contiguous())

        # Softmax over the other sentence, ignoring its padded positions
        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
        hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(), premise_mask)

        # Soft-aligned representations of each sentence in terms of the other
        attended_premises = weighted_sum(hypothesis_batch, prem_hyp_attn, premise_mask)
        attended_hypotheses = weighted_sum(premise_batch, hyp_prem_attn, hypothesis_mask)

        return attended_premises, attended_hypotheses
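

# A minimal smoke-test sketch, not part of the original module. It assumes
# fastNLP's Vocabulary and StaticEmbedding are available and that StaticEmbedding
# accepts model_dir_or_name=None with an explicit embedding_dim for randomly
# initialised vectors; the vocabulary words, batch size, and sequence lengths
# below are illustrative only.
if __name__ == '__main__':
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst(['a', 'man', 'is', 'playing', 'guitar', 'nobody', 'sleeping'])
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)

    model = ESIMModel(embed, hidden_size=50, num_labels=3)
    model.eval()

    # Dummy batches of token ids; indices 0/1 are reserved for pad/unk in fastNLP's Vocabulary.
    words1 = torch.randint(2, len(vocab), (4, 7))   # [batch, premise_len]
    words2 = torch.randint(2, len(vocab), (4, 5))   # [batch, hypothesis_len]
    seq_len1 = torch.tensor([7, 6, 5, 4])
    seq_len2 = torch.tensor([5, 5, 3, 2])

    with torch.no_grad():
        out = model(words1, words2, seq_len1, seq_len2)
        pred = model.predict(words1, words2, seq_len1, seq_len2)
    print(out[Const.OUTPUT].shape)   # expected: torch.Size([4, 3])
    print(pred[Const.OUTPUT].shape)  # expected: torch.Size([4])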