
TForiginal.py 5.6 kB

#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
import torch.nn as nn

from .Encoder import Encoder
# from tools.Encoder import Encoder
from tools.PositionEmbedding import get_sinusoid_encoding_table
from tools.logger import *

from fastNLP.core.const import Const
from fastNLP.modules.encoder.transformer import TransformerEncoder

from transformer.Layers import EncoderLayer

class TransformerModel(nn.Module):
    def __init__(self, hps, embed):
        """
        :param hps:
                min_kernel_size: min kernel size for the CNN encoder
                max_kernel_size: max kernel size for the CNN encoder
                output_channel: number of output channels for the CNN encoder
                hidden_size: hidden size of the transformer
                n_layers: number of transformer encoder layers
                n_head: number of heads for multi-head attention
                ffn_inner_hidden_size: FFN hidden size
                atten_dropout_prob: attention dropout probability
                doc_max_timesteps: max number of sentences in a document
        :param embed: word embedding
        """
        super(TransformerModel, self).__init__()

        self._hps = hps
        self.encoder = Encoder(hps, embed)

        # The CNN encoder concatenates one output-channel group per kernel size
        self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel
        self.hidden_size = hps.hidden_size
        self.n_head = hps.n_head
        self.d_v = self.d_k = int(self.hidden_size / self.n_head)
        self.d_inner = hps.ffn_inner_hidden_size
        self.num_layers = hps.n_layers

        # Project sentence embeddings to the transformer hidden size and add
        # fixed sinusoid position embeddings (index 0 is reserved for padding)
        self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size)
        self.sent_pos_embed = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(self.hidden_size, self.d_inner, self.n_head, self.d_k, self.d_v,
                         dropout=hps.atten_dropout_prob)
            for _ in range(self.num_layers)])

        # Binary classification head: select / do not select each sentence
        self.wh = nn.Linear(self.hidden_size, 2)

    def forward(self, words, seq_len):
        """
        :param words: [batch_size, N, seq_len]
        :param seq_len: [batch_size, N]
        :return:
        """
        # Sentence Encoder
        input = words
        input_len = seq_len
        self.sent_embedding = self.encoder(input)  # [batch, N, Co * kernel_sizes]
        input_len = input_len.float()  # [batch, N]

        # -- Prepare masks
        batch_size, N = input_len.size()
        self.slf_attn_mask = input_len.eq(0.0)  # [batch, N]
        self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1)  # [batch, N, N]
        self.non_pad_mask = input_len.unsqueeze(-1)  # [batch, N, 1]

        # Position ids: 1..doclen for real sentences, 0 (padding) elsewhere
        input_doc_len = input_len.sum(dim=1).int()  # [batch]
        sent_pos = torch.Tensor(
            [np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len])
        sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long()

        enc_output_state = self.projection(self.sent_embedding)
        enc_input = enc_output_state + self.sent_pos_embed(sent_pos)
        # self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask

        enc_input_list = []
        for enc_layer in self.layer_stack:
            # enc_output = [batch_size, N, hidden_size = n_head * d_v]
            # enc_slf_attn = [n_head * batch_size, N, N]
            enc_input, enc_slf_attn = enc_layer(enc_input, non_pad_mask=self.non_pad_mask,
                                                slf_attn_mask=self.slf_attn_mask)
            enc_input_list += [enc_input]

        # Sum the outputs of the last four encoder layers
        self.dec_output_state = torch.cat(enc_input_list[-4:])  # [4 * batch_size, N, hidden_size]
        self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1)
        self.dec_output_state = self.dec_output_state.sum(0)  # [batch_size, N, hidden_size]

        p_sent = self.wh(self.dec_output_state)  # [batch, N, 2]

        idx = None
        if self._hps.m == 0:
            # Label each sentence independently by the argmax of its two logits
            prediction = p_sent.view(-1, 2).max(1)[1]
            prediction = prediction.view(batch_size, -1)
        else:
            # Select the m sentences with the highest "select" score,
            # masking padded positions so they cannot be chosen
            mask_output = torch.exp(p_sent[:, :, 1])  # [batch, N]
            mask_output = mask_output.masked_fill(input_len.eq(0), 0)
            topk, idx = torch.topk(mask_output, self._hps.m)
            prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1)
            prediction = prediction.long().view(batch_size, -1)
            if self._hps.cuda:
                prediction = prediction.cuda()

        # logger.debug(((p_sent.size(), prediction.size(), idx.size())))
        return {"p_sent": p_sent, "prediction": prediction, "pred_idx": idx}