
TransformerModel.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torch.nn as nn

from .Encoder import Encoder
from tools.PositionEmbedding import get_sinusoid_encoding_table

from fastNLP.core.const import Const
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoderLayer
class TransformerModel(nn.Module):
    def __init__(self, hps, vocab):
        """
        :param hps:
            min_kernel_size: minimum kernel size for the CNN sentence encoder
            max_kernel_size: maximum kernel size for the CNN sentence encoder
            output_channel: number of output channels for the CNN sentence encoder
            hidden_size: hidden size of the transformer
            n_layers: number of transformer encoder layers
            n_head: number of attention heads in the transformer
            ffn_inner_hidden_size: inner hidden size of the FFN sublayer
            atten_dropout_prob: attention dropout probability
            doc_max_timesteps: maximum number of sentences in a document
        :param vocab: vocabulary
        """
        super(TransformerModel, self).__init__()

        self._hps = hps
        self._vocab = vocab

        # CNN sentence encoder
        self.encoder = Encoder(hps, vocab)
        self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel

        self.hidden_size = hps.hidden_size
        self.n_head = hps.n_head
        self.d_v = self.d_k = int(self.hidden_size / self.n_head)
        self.d_inner = hps.ffn_inner_hidden_size
        self.num_layers = hps.n_layers

        # Project CNN sentence embeddings to the transformer hidden size and add
        # fixed sinusoidal embeddings over sentence positions
        self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size)
        self.sent_pos_embed = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True)

        self.layer_stack = nn.ModuleList([
            TransformerSeq2SeqEncoderLayer(d_model=self.hidden_size, n_head=self.n_head, dim_ff=self.d_inner,
                                           dropout=hps.atten_dropout_prob)
            for _ in range(self.num_layers)])

        # Two-way classifier over each sentence representation
        self.wh = nn.Linear(self.hidden_size, 2)
    def forward(self, words, seq_len):
        """
        :param words: [batch_size, N, seq_len]
        :param seq_len: [batch_size, N], 0 for padded sentences
        :return: dict with
            "pred": [batch_size, N, 2] sentence-level logits,
            "prediction": [batch_size, N] binary labels of the selected sentences,
            "pred_idx": [batch_size, hps.m] indices of the selected sentences (None when hps.m == 0)
        """
        # Sentence Encoder
        input = words
        input_len = seq_len
        self.sent_embedding = self.encoder(input)  # [batch, N, Co * kernel_sizes]
        input_len = input_len.float()  # [batch, N]

        # -- Prepare masks
        batch_size, N = input_len.size()
        self.slf_attn_mask = input_len.eq(0.0)  # [batch, N], True on padded sentences
        self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1)  # [batch, N, N]
        self.non_pad_mask = input_len.unsqueeze(-1)  # [batch, N, 1]

        # Sentence position ids: 1..doc_len for real sentences, 0 for padding
        input_doc_len = input_len.sum(dim=1).int()  # [batch]
        sent_pos = torch.Tensor(
            [np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len])
        sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long()

        enc_output_state = self.projection(self.sent_embedding)
        enc_input = enc_output_state + self.sent_pos_embed(sent_pos)
        # self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask

        enc_input_list = []
        for enc_layer in self.layer_stack:
            # enc_output = [batch_size, N, hidden_size = n_head * d_v]
            # enc_slf_attn = [n_head * batch_size, N, N]
            enc_input = enc_layer(enc_input, encoder_mask=self.slf_attn_mask)
            enc_input_list += [enc_input]

        # Sum the outputs of the last four encoder layers as the final sentence states
        self.dec_output_state = torch.cat(enc_input_list[-4:])  # [4 * batch_size, N, hidden_size]
        self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1)
        self.dec_output_state = self.dec_output_state.sum(0)  # [batch_size, N, hidden_size]

        p_sent = self.wh(self.dec_output_state)  # [batch, N, 2]

        idx = None
        if self._hps.m == 0:
            # Label every sentence independently by its argmax class
            prediction = p_sent.view(-1, 2).max(1)[1]
            prediction = prediction.view(batch_size, -1)
        else:
            # Select the top-m sentences by their positive-class score, ignoring padding
            mask_output = torch.exp(p_sent[:, :, 1])  # [batch, N]
            mask_output = mask_output * input_len.float()
            topk, idx = torch.topk(mask_output, self._hps.m)
            prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1)
            prediction = prediction.long().view(batch_size, -1)
            if self._hps.cuda:
                prediction = prediction.cuda()
        # print((p_sent.size(), prediction.size(), idx.size()))
        # [batch, N, 2], [batch, N], [batch, hps.m]
        return {"pred": p_sent, "prediction": prediction, "pred_idx": idx}
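
Below is a minimal, self-contained sketch (not part of this file or repository) of the mask construction, sentence-position ids, and top-m selection performed in forward(), run on a dummy seq_len tensor. It assumes seq_len marks real sentences with 1 and padded sentences with 0, which is the convention the position-id construction above relies on; all names and numbers are illustrative.

# Illustrative sketch only -- assumes seq_len is a 0/1 sentence mask, as the
# position-id construction in forward() requires; not part of the repository.
import numpy as np
import torch

batch_size, N, m = 2, 5, 2
seq_len = torch.tensor([[1., 1., 1., 0., 0.],   # document 1: 3 real sentences
                        [1., 1., 1., 1., 0.]])  # document 2: 4 real sentences

# Self-attention mask over sentences: True where a sentence is padding
slf_attn_mask = seq_len.eq(0.0).unsqueeze(1).expand(-1, N, -1)   # [batch, N, N]

# Sentence position ids: 1..doc_len for real sentences, 0 for padding
doc_len = seq_len.sum(dim=1).int().tolist()                      # [3, 4]
sent_pos = torch.Tensor(
    [np.hstack((np.arange(1, d + 1), np.zeros(N - d))) for d in doc_len]).long()

# Top-m selection as in the hps.m > 0 branch, on dummy scores
scores = torch.rand(batch_size, N) * seq_len                     # zero out padding
topk, idx = torch.topk(scores, m)                                # [batch, m]
prediction = torch.zeros(batch_size, N).scatter_(1, idx, 1).long()

print(slf_attn_mask.shape)   # torch.Size([2, 5, 5])
print(sent_pos)              # tensor([[1, 2, 3, 0, 0], [1, 2, 3, 4, 0]])
print(prediction)            # two 1s per row, at the top-scoring real sentences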