
cntn.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from fastNLP.models.base_model import BaseModel
from fastNLP.embeddings import TokenEmbedding
from fastNLP.core.const import Const

class DynamicKMaxPooling(nn.Module):
    """
    :param k_top: Fixed number of pooling output features for the topmost convolutional layer.
    :param l: Total number of convolutional layers.
    """

    def __init__(self, k_top, l):
        super(DynamicKMaxPooling, self).__init__()
        self.k_top = k_top
        self.L = l

    def forward(self, x, l):
        """
        :param x: Input tensor of shape [batch, channels, emb_size, seq_len].
        :param l: Index (1-based) of the current convolutional layer.
        """
        s = x.size()[3]
        # Dynamic k: shrink linearly with depth, but never below k_top.
        k_ll = ((self.L - l) / self.L) * s
        k_l = int(round(max(self.k_top, np.ceil(k_ll))))
        out = F.adaptive_max_pool2d(x, (x.size()[2], k_l))
        return out
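
# A minimal sketch of the pooling schedule (the values below are illustrative,
# not from the original file): with L = 2 layers, k_top = 10 and sequence
# width s = 28, layer 1 keeps k_1 = max(10, ceil(((2 - 1) / 2) * 28)) = 14
# columns, while the topmost layer keeps k_2 = k_top = 10.
#
#     pool = DynamicKMaxPooling(k_top=10, l=2)
#     x = torch.randn(4, 1, 50, 28)    # [batch, channels, emb, seq]
#     pool(x, 1).shape                 # torch.Size([4, 1, 50, 14])
#     pool(x, 2).shape                 # torch.Size([4, 1, 50, 10])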

class CNTNModel(BaseModel):
    """
    A CNN-based model for question-answer matching.

    'Qiu, Xipeng, and Xuanjing Huang.
    Convolutional neural tensor network architecture for community-based question answering.
    Twenty-Fourth International Joint Conference on Artificial Intelligence. 2015.'

    :param init_embedding: Embedding.
    :param ns: Sentence embedding size.
    :param k_top: Fixed number of pooling output features for the topmost convolutional layer.
    :param num_labels: Number of labels.
    :param depth: Number of convolutional layers.
    :param r: Number of weight tensor slices.
    :param dropout_rate: Dropout rate.
    """
    def __init__(self, init_embedding: TokenEmbedding, ns=200, k_top=10, num_labels=2, depth=2, r=5,
                 dropout_rate=0.3):
        super(CNTNModel, self).__init__()
        self.embedding = init_embedding
        self.depth = depth
        self.kmaxpooling = DynamicKMaxPooling(k_top, depth)
        self.conv_q = nn.ModuleList()
        self.conv_a = nn.ModuleList()
        width = self.embedding.embed_size
        for i in range(depth):
            # Each layer convolves across the full feature axis and halves the
            # channel count; the channels become the feature axis of the next
            # layer (see the squeeze/unsqueeze in forward).
            self.conv_q.append(nn.Sequential(
                nn.Dropout(p=dropout_rate),
                nn.Conv2d(
                    in_channels=1,
                    out_channels=width // 2,
                    kernel_size=(width, 3),
                    padding=(0, 2))
            ))
            self.conv_a.append(nn.Sequential(
                nn.Dropout(p=dropout_rate),
                nn.Conv2d(
                    in_channels=1,
                    out_channels=width // 2,
                    kernel_size=(width, 3),
                    padding=(0, 2))
            ))
            width = width // 2
        self.fc_q = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(width * k_top, ns))
        self.fc_a = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(width * k_top, ns))
        self.weight_M = nn.Bilinear(ns, ns, r)
        self.weight_V = nn.Linear(2 * ns, r)
        self.weight_u = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(r, num_labels))
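
    # Shape bookkeeping (embed_size = 100 is a hypothetical example, not a
    # value fixed by this file): the feature width halves per layer,
    # 100 -> 50 -> 25, and the topmost pooling keeps k_top columns, so each
    # fc layer flattens width * k_top = 25 * 10 = 250 features into an
    # ns-dimensional sentence vector.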

    def forward(self, words1, words2, seq_len1, seq_len2):
        """
        :param words1: [batch, seq_len] Token indices of the question.
        :param words2: [batch, seq_len] Token indices of the answer.
        :param seq_len1: [batch]
        :param seq_len2: [batch]
        :return: Dict mapping Const.OUTPUT to scores of shape [batch, num_labels].
        """
        in_q = self.embedding(words1)
        in_a = self.embedding(words2)
        # [batch, seq_len, emb_size] -> [batch, 1, emb_size, seq_len] for Conv2d.
        in_q = in_q.permute(0, 2, 1).unsqueeze(1)
        in_a = in_a.permute(0, 2, 1).unsqueeze(1)
        for i in range(self.depth):
            in_q = F.relu(self.conv_q[i](in_q))
            # Squeeze only the collapsed height axis; a bare squeeze() would
            # also drop a batch dimension of size 1.
            in_q = in_q.squeeze(2).unsqueeze(1)
            in_q = self.kmaxpooling(in_q, i + 1)
            in_a = F.relu(self.conv_a[i](in_a))
            in_a = in_a.squeeze(2).unsqueeze(1)
            in_a = self.kmaxpooling(in_a, i + 1)
        in_q = self.fc_q(in_q.view(in_q.size(0), -1))
        # Project the answer through its own fc layer (fc_a), not fc_q.
        in_a = self.fc_a(in_a.view(in_a.size(0), -1))
        score = torch.tanh(self.weight_u(self.weight_M(in_q, in_a) + self.weight_V(torch.cat((in_q, in_a), -1))))
        return {Const.OUTPUT: score}

    def predict(self, words1, words2, seq_len1, seq_len2):
        return self.forward(words1, words2, seq_len1, seq_len2)
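
As a quick smoke test, the sketch below wires CNTNModel to a randomly initialized fastNLP StaticEmbedding. The toy vocabulary, the embedding size of 100, and the batch shapes are illustrative assumptions rather than part of this file; any TokenEmbedding with a compatible embed_size would work the same way.

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst(['how', 'do', 'i', 'train', 'a', 'model'])  # toy vocab
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=100)
    model = CNTNModel(embed)                       # defaults: ns=200, k_top=10, depth=2
    words = torch.randint(2, len(vocab), (4, 12))  # [batch, seq_len] token ids
    lens = torch.full((4,), 12, dtype=torch.long)
    out = model(words, words, lens, lens)
    print(out[Const.OUTPUT].shape)                 # torch.Size([4, 2])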