
esim.py 8.2 kB

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from fastNLP.models import BaseModel
from fastNLP.modules.encoder.embedding import TokenEmbedding
from fastNLP.modules.encoder.lstm import LSTM
from fastNLP.core.const import Const
from fastNLP.core.utils import seq_len_to_mask


class ESIMModel(BaseModel):
    """ESIM (Enhanced Sequential Inference Model) for sentence-pair classification,
    e.g. natural language inference over a premise/hypothesis pair."""

    def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
                 dropout_embed=0.1):
        super(ESIMModel, self).__init__()

        self.embedding = init_embedding
        self.dropout_embed = EmbedDropout(p=dropout_embed)
        if hidden_size is None:
            hidden_size = self.embedding.embed_size

        # Input-encoding BiLSTM over the embedded tokens.
        self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
        # self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True)

        # Projection of the enhanced local-inference features [a, ai, a - ai, a * ai].
        self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate),
                                       nn.Linear(8 * hidden_size, hidden_size),
                                       nn.ReLU())
        nn.init.xavier_uniform_(self.interfere[1].weight.data)
        self.bi_attention = SoftmaxAttention()

        # Inference-composition BiLSTM; its input is the hidden_size-dim output of `interfere`.
        self.rnn_high = BiRNN(hidden_size, hidden_size, dropout_rate=dropout_rate)
        # self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,)

        self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate),
                                        nn.Linear(8 * hidden_size, hidden_size),
                                        nn.Tanh(),
                                        nn.Dropout(p=dropout_rate),
                                        nn.Linear(hidden_size, num_labels))

        self.dropout_rnn = nn.Dropout(p=dropout_rate)

        nn.init.xavier_uniform_(self.classifier[1].weight.data)
        nn.init.xavier_uniform_(self.classifier[4].weight.data)

    def forward(self, words1, words2, seq_len1, seq_len2, target=None):
        """
        :param words1: [batch, seq_len] premise word ids
        :param words2: [batch, seq_len] hypothesis word ids
        :param seq_len1: [batch] true lengths of words1
        :param seq_len2: [batch] true lengths of words2
        :param target: [batch] gold label ids (optional)
        :return: dict with the prediction logits, plus the loss when target is given
        """
        mask1 = seq_len_to_mask(seq_len1, words1.size(1))
        mask2 = seq_len_to_mask(seq_len2, words2.size(1))
        a0 = self.embedding(words1)  # B * len * emb_dim
        b0 = self.embedding(words2)
        a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0)
        a = self.rnn(a0, mask1.byte())  # a: [B, PL, 2 * H]
        b = self.rnn(b0, mask2.byte())
        # a = self.dropout_rnn(self.rnn(a0, seq_len1)[0])  # a: [B, PL, 2 * H]
        # b = self.dropout_rnn(self.rnn(b0, seq_len2)[0])

        ai, bi = self.bi_attention(a, mask1, b, mask2)

        a_ = torch.cat((a, ai, a - ai, a * ai), dim=2)  # ma: [B, PL, 8 * H]
        b_ = torch.cat((b, bi, b - bi, b * bi), dim=2)
        a_f = self.interfere(a_)
        b_f = self.interfere(b_)

        a_h = self.rnn_high(a_f, mask1.byte())  # ma: [B, PL, 2 * H]
        b_h = self.rnn_high(b_f, mask2.byte())
        # a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0])  # ma: [B, PL, 2 * H]
        # b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0])

        a_avg = self.mean_pooling(a_h, mask1, dim=1)
        a_max, _ = self.max_pooling(a_h, mask1, dim=1)
        b_avg = self.mean_pooling(b_h, mask2, dim=1)
        b_max, _ = self.max_pooling(b_h, mask2, dim=1)

        out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1)  # v: [B, 8 * H]
        logits = torch.tanh(self.classifier(out))

        if target is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, target)
            return {Const.LOSS: loss, Const.OUTPUT: logits}
        else:
            return {Const.OUTPUT: logits}

    def predict(self, **kwargs):
        return self.forward(**kwargs)

    # input: [batch_size, len, hidden]
    # mask:  [batch_size, len] (1 1 1 ... 0 0)
    @staticmethod
    def mean_pooling(input, mask, dim=1):
        masks = mask.view(mask.size(0), mask.size(1), -1).float()
        return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1)

    @staticmethod
    def max_pooling(input, mask, dim=1):
        my_inf = 10e12
        masks = mask.view(mask.size(0), mask.size(1), -1)
        masks = masks.expand(-1, -1, input.size(2)).float()
        return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim)
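
    # For example, with input of shape [B, L, H] and a mask row of [1, 1, 0]:
    # mean_pooling averages only the first two timesteps of that row, and
    # max_pooling adds -1e13 to the masked timestep so it can never win the max.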


class EmbedDropout(nn.Dropout):
    """Dropout over embedding dimensions: the same channels are zeroed at every
    timestep of a sequence, rather than dropping entries independently."""

    def forward(self, sequences_batch):
        ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1])
        dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False)
        return dropout_mask.unsqueeze(1) * sequences_batch


class BiRNN(nn.Module):
    """Single-layer bidirectional LSTM that sorts, packs and re-pads variable-length
    sequences, applying dropout to the packed inputs."""

    def __init__(self, input_size, hidden_size, dropout_rate=0.3):
        super(BiRNN, self).__init__()
        self.dropout_rate = dropout_rate
        self.rnn = nn.LSTM(input_size, hidden_size,
                           num_layers=1,
                           bidirectional=True,
                           batch_first=True)

    def forward(self, x, x_mask):
        # Sort sequences by length (longest first) for packing
        lengths = x_mask.data.eq(1).long().sum(1).squeeze()
        _, idx_sort = torch.sort(lengths, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lengths = list(lengths[idx_sort])
        x = x.index_select(0, idx_sort)

        # Pack it up
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
        # Apply dropout to input
        if self.dropout_rate > 0:
            dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training)
            rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes)
        output = self.rnn(rnn_input)[0]

        # Unpack everything and restore the original batch order
        output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0]
        output = output.index_select(0, idx_unsort)

        # Pad back to the full mask length if the longest sequence was shorter
        if output.size(1) != x_mask.size(1):
            padding = torch.zeros(output.size(0),
                                  x_mask.size(1) - output.size(1),
                                  output.size(2)).type(output.data.type())
            output = torch.cat([output, padding], 1)
        return output


def masked_softmax(tensor, mask):
    """Softmax over the last dimension that zeroes the weights of masked (mask == 0)
    positions and renormalises the remaining ones."""
    tensor_shape = tensor.size()
    reshaped_tensor = tensor.view(-1, tensor_shape[-1])

    # Reshape the mask so it matches the size of the input tensor.
    while mask.dim() < tensor.dim():
        mask = mask.unsqueeze(1)
    mask = mask.expand_as(tensor).contiguous().float()
    reshaped_mask = mask.view(-1, mask.size()[-1])

    result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
    result = result * reshaped_mask
    # 1e-13 is added to avoid divisions by zero.
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result.view(*tensor_shape)
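
# For example, masked_softmax(torch.tensor([[[1., 2., 3.]]]), torch.tensor([[1, 1, 0]]))
# gives zero weight to the third (padded) position and renormalises the first two
# weights so that they sum to 1.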


def weighted_sum(tensor, weights, mask):
    """Attention-weighted sum of `tensor`, with padded output positions zeroed by `mask`."""
    w_sum = weights.bmm(tensor)
    while mask.dim() < w_sum.dim():
        mask = mask.unsqueeze(1)
    mask = mask.transpose(-1, -2)
    mask = mask.expand_as(w_sum).contiguous().float()
    return w_sum * mask


class SoftmaxAttention(nn.Module):
    """Soft alignment between premise and hypothesis: dot-product similarity followed
    by a masked softmax in each direction."""

    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
                                              .contiguous())

        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
        hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2)
                                       .contiguous(),
                                       premise_mask)

        attended_premises = weighted_sum(hypothesis_batch,
                                         prem_hyp_attn,
                                         premise_mask)
        attended_hypotheses = weighted_sum(premise_batch,
                                           hyp_prem_attn,
                                           hypothesis_mask)

        return attended_premises, attended_hypotheses
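

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). ESIMModel only
# needs `init_embedding` to behave like a fastNLP TokenEmbedding: callable on a
# [batch, seq_len] tensor of word ids and exposing an `embed_size` attribute.
# `ToyEmbedding` below is a hypothetical stand-in that keeps the example
# self-contained; in practice a fastNLP embedding built from a Vocabulary
# would be passed instead.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class ToyEmbedding(nn.Module):
        """Minimal stand-in for a TokenEmbedding (assumption, for illustration only)."""

        def __init__(self, vocab_size=100, embed_size=50):
            super().__init__()
            self.embed_size = embed_size
            self.embed = nn.Embedding(vocab_size, embed_size, padding_idx=0)

        def forward(self, words):
            return self.embed(words)

    model = ESIMModel(ToyEmbedding(), num_labels=3)
    words1 = torch.randint(1, 100, (4, 7))   # [batch, seq_len] premise word ids
    words2 = torch.randint(1, 100, (4, 5))   # [batch, seq_len] hypothesis word ids
    seq_len1 = torch.tensor([7, 6, 5, 3])    # true lengths of words1
    seq_len2 = torch.tensor([5, 5, 4, 2])    # true lengths of words2
    target = torch.tensor([0, 1, 2, 1])      # gold labels in [0, num_labels)

    output = model(words1, words2, seq_len1, seq_len2, target)
    print(output[Const.LOSS].item(), output[Const.OUTPUT].shape)  # scalar loss, torch.Size([4, 3])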