
HAN.py 4.4 kB

import torch
import torch.nn as nn
from torch.autograd import Variable
from fastNLP.modules.utils import get_embeddings
from fastNLP.core import Const as C


def pack_sequence(tensor_seq, padding_value=0.0):
    # Pad a list of variable-length tensors, each of shape (len_i, *dims),
    # into a single (batch, max_len, *dims) tensor filled with padding_value.
    if len(tensor_seq) <= 0:
        return
    length = [v.size(0) for v in tensor_seq]
    max_len = max(length)
    size = [len(tensor_seq), max_len]
    size.extend(list(tensor_seq[0].size()[1:]))
    ans = torch.Tensor(*size).fill_(padding_value)
    if tensor_seq[0].data.is_cuda:
        ans = ans.cuda()
    ans = Variable(ans)
    for i, v in enumerate(tensor_seq):
        ans[i, :length[i], :] = v
    return ans
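
A minimal sketch of pack_sequence in isolation, with made-up shapes (two documents of 2 and 3 sentence vectors, dimension 4); it is illustrative, not part of HAN.py:

# Illustrative only: pad two variable-length matrices into one batch tensor.
doc_a = torch.randn(2, 4)
doc_b = torch.randn(3, 4)
batch = pack_sequence([doc_a, doc_b], padding_value=0.0)
print(batch.shape)  # torch.Size([2, 3, 4]); batch[0, 2] is zero padding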
class HANCLS(nn.Module):
    def __init__(self, init_embed, num_cls):
        super(HANCLS, self).__init__()
        self.embed = get_embeddings(init_embed)
        self.han = HAN(input_size=300,
                       output_size=num_cls,
                       word_hidden_size=50, word_num_layers=1, word_context_size=100,
                       sent_hidden_size=50, sent_num_layers=1, sent_context_size=100
                       )

    def forward(self, input_sents):
        # input_sents: [B, num_sents, seq_len], dtype long
        B, num_sents, seq_len = input_sents.size()
        input_sents = input_sents.view(-1, seq_len)  # flatten to [B * num_sents, seq_len]
        words_embed = self.embed(input_sents)  # [B * num_sents, seq_len, word_dim]
        words_embed = words_embed.view(B, num_sents, seq_len, -1)  # recover [B, num_sents, seq_len, word_dim]
        out = self.han(words_embed)
        return {C.OUTPUT: out}

    def predict(self, input_sents):
        x = self.forward(input_sents)[C.OUTPUT]
        return {C.OUTPUT: torch.argmax(x, 1)}
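
A hedged usage sketch for HANCLS. The vocabulary size, class count and batch shape below are invented, and it assumes fastNLP's get_embeddings builds a randomly initialised nn.Embedding from a (vocab_size, embed_dim) tuple; the embedding dim must be 300 because HAN is constructed with input_size=300:

# Illustrative only: 2 documents, 4 sentences each, 10 tokens per sentence.
model = HANCLS(init_embed=(5000, 300), num_cls=3)
docs = torch.randint(0, 5000, (2, 4, 10))   # [B, num_sents, seq_len], dtype long
log_probs = model(docs)[C.OUTPUT]           # [2, 3] log-probabilities
preds = model.predict(docs)[C.OUTPUT]       # [2] predicted class indices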
class HAN(nn.Module):
    def __init__(self, input_size, output_size,
                 word_hidden_size, word_num_layers, word_context_size,
                 sent_hidden_size, sent_num_layers, sent_context_size):
        super(HAN, self).__init__()
        self.word_layer = AttentionNet(input_size,
                                       word_hidden_size,
                                       word_num_layers,
                                       word_context_size)
        self.sent_layer = AttentionNet(2 * word_hidden_size,
                                       sent_hidden_size,
                                       sent_num_layers,
                                       sent_context_size)
        self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, batch_doc):
        # batch_doc: a sequence of per-document tensors, each (num_sent, seq_len, word_dim)
        doc_vec_list = []
        for doc in batch_doc:
            sent_mat = self.word_layer(doc)  # (num_sent, 2 * word_hidden_size)
            doc_vec_list.append(sent_mat)
        doc_vec = self.sent_layer(pack_sequence(doc_vec_list))  # (batch, 2 * sent_hidden_size)
        output = self.softmax(self.output_layer(doc_vec))  # (batch, output_size) log-probabilities
        return output
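
HAN can also be driven directly with text that is already embedded; the shapes below are made up and mirror the hyper-parameters HANCLS hard-codes (word_dim 300, 50-unit GRUs):

# Illustrative only: 2 documents, 4 sentences each, 10 tokens per sentence.
han = HAN(input_size=300, output_size=3,
          word_hidden_size=50, word_num_layers=1, word_context_size=100,
          sent_hidden_size=50, sent_num_layers=1, sent_context_size=100)
embedded_docs = torch.randn(2, 4, 10, 300)  # [B, num_sents, seq_len, word_dim]
log_probs = han(embedded_docs)              # [2, 3] after LogSoftmax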
class AttentionNet(nn.Module):
    def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
        super(AttentionNet, self).__init__()
        self.input_size = input_size
        self.gru_hidden_size = gru_hidden_size
        self.gru_num_layers = gru_num_layers
        self.context_vec_size = context_vec_size

        # Encoder: bidirectional GRU
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=gru_hidden_size,
                          num_layers=gru_num_layers,
                          batch_first=True,
                          bidirectional=True)

        # Attention
        self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)

        # Context vector
        self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
        self.context_vec.data.uniform_(-0.1, 0.1)

    def forward(self, inputs):
        # GRU encoder: inputs is (batch_size, seq_len, input_size)
        h_t, hidden = self.gru(inputs)  # h_t: (batch_size, seq_len, 2 * gru_hidden_size)
        u = self.tanh(self.fc(h_t))  # u: (batch_size, seq_len, context_vec_size)
        # Attention pooling over the sequence dimension
        alpha = self.softmax(torch.matmul(u, self.context_vec))  # alpha: (batch_size, seq_len, 1)
        output = torch.bmm(torch.transpose(h_t, 1, 2), alpha)  # (batch_size, 2 * gru_hidden_size, 1)
        return torch.squeeze(output, dim=2)  # (batch_size, 2 * gru_hidden_size)
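
Finally, a shape sketch for AttentionNet on its own, with made-up sizes, showing that the attention pooling collapses the time dimension:

# Illustrative only: 4 sequences of 10 steps with 300-dim inputs.
attn = AttentionNet(input_size=300, gru_hidden_size=50, gru_num_layers=1, context_vec_size=100)
x = torch.randn(4, 10, 300)
pooled = attn(x)
print(pooled.shape)  # torch.Size([4, 100]) == (batch, 2 * gru_hidden_size)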