You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TForiginal.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. # __author__="Danqing Wang"
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ==============================================================================
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import numpy as np
  21. import torch
  22. import torch.nn as nn
  23. from .Encoder import Encoder
  24. # from tools.Encoder import Encoder
  25. from tools.PositionEmbedding import get_sinusoid_encoding_table
  26. from tools.logger import *
  27. from fastNLP.core.const import Const
  28. from transformer.Layers import EncoderLayer
  29. class TransformerModel(nn.Module):
  30. def __init__(self, hps, embed):
  31. """
  32. :param hps:
  33. min_kernel_size: min kernel size for cnn encoder
  34. max_kernel_size: max kernel size for cnn encoder
  35. output_channel: output_channel number for cnn encoder
  36. hidden_size: hidden size for transformer
  37. n_layers: transfromer encoder layer
  38. n_head: multi head attention for transformer
  39. ffn_inner_hidden_size: FFN hiddens size
  40. atten_dropout_prob: dropout size
  41. doc_max_timesteps: max sentence number of the document
  42. :param embed: word embedding
  43. """
  44. super(TransformerModel, self).__init__()
  45. self._hps = hps
  46. self.encoder = Encoder(hps, embed)
  47. self.sent_embedding_size = (hps.max_kernel_size - hps.min_kernel_size + 1) * hps.output_channel
  48. self.hidden_size = hps.hidden_size
  49. self.n_head = hps.n_head
  50. self.d_v = self.d_k = int(self.hidden_size / self.n_head)
  51. self.d_inner = hps.ffn_inner_hidden_size
  52. self.num_layers = hps.n_layers
  53. self.projection = nn.Linear(self.sent_embedding_size, self.hidden_size)
  54. self.sent_pos_embed = nn.Embedding.from_pretrained(
  55. get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True)
  56. self.layer_stack = nn.ModuleList([
  57. EncoderLayer(self.hidden_size, self.d_inner, self.n_head, self.d_k, self.d_v,
  58. dropout=hps.atten_dropout_prob)
  59. for _ in range(self.num_layers)])
  60. self.wh = nn.Linear(self.hidden_size, 2)
  61. def forward(self, words, seq_len):
  62. """
  63. :param input: [batch_size, N, seq_len]
  64. :param input_len: [batch_size, N]
  65. :return:
  66. """
  67. # Sentence Encoder
  68. input = words
  69. input_len = seq_len
  70. self.sent_embedding = self.encoder(input) # [batch, N, Co * kernel_sizes]
  71. input_len = input_len.float() # [batch, N]
  72. # -- Prepare masks
  73. batch_size, N = input_len.size()
  74. self.slf_attn_mask = input_len.eq(0.0) # [batch, N]
  75. self.slf_attn_mask = self.slf_attn_mask.unsqueeze(1).expand(-1, N, -1) # [batch, N, N]
  76. self.non_pad_mask = input_len.unsqueeze(-1) # [batch, N, 1]
  77. input_doc_len = input_len.sum(dim=1).int() # [batch]
  78. sent_pos = torch.Tensor(
  79. [np.hstack((np.arange(1, doclen + 1), np.zeros(N - doclen))) for doclen in input_doc_len])
  80. sent_pos = sent_pos.long().cuda() if self._hps.cuda else sent_pos.long()
  81. enc_output_state = self.projection(self.sent_embedding)
  82. enc_input = enc_output_state + self.sent_pos_embed(sent_pos)
  83. # self.enc_slf_attn = self.enc_slf_attn * self.non_pad_mask
  84. enc_input_list = []
  85. for enc_layer in self.layer_stack:
  86. # enc_output = [batch_size, N, hidden_size = n_head * d_v]
  87. # enc_slf_attn = [n_head * batch_size, N, N]
  88. enc_input, enc_slf_atten = enc_layer(enc_input, non_pad_mask=self.non_pad_mask,
  89. slf_attn_mask=self.slf_attn_mask)
  90. enc_input_list += [enc_input]
  91. self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state]
  92. self.dec_output_state = self.dec_output_state.view(4, batch_size, N, -1)
  93. self.dec_output_state = self.dec_output_state.sum(0)
  94. p_sent = self.wh(self.dec_output_state) # [batch, N, 2]
  95. idx = None
  96. if self._hps.m == 0:
  97. prediction = p_sent.view(-1, 2).max(1)[1]
  98. prediction = prediction.view(batch_size, -1)
  99. else:
  100. mask_output = torch.exp(p_sent[:, :, 1]) # # [batch, N]
  101. mask_output = mask_output.masked_fill(input_len.eq(0), 0)
  102. topk, idx = torch.topk(mask_output, self._hps.m)
  103. prediction = torch.zeros(batch_size, N).scatter_(1, idx.data.cpu(), 1)
  104. prediction = prediction.long().view(batch_size, -1)
  105. if self._hps.cuda:
  106. prediction = prediction.cuda()
  107. # logger.debug(((p_sent.size(), prediction.size(), idx.size())))
  108. return {"p_sent": p_sent, "prediction": prediction, "pred_idx": idx}