dpcnn.py
import torch
import torch.nn as nn
from fastNLP.modules.utils import get_embeddings
from fastNLP.core import Const as C


class DPCNN(nn.Module):
    def __init__(self, init_embed, num_cls, n_filters=256,
                 kernel_size=3, n_layers=7, embed_dropout=0.1, cls_dropout=0.1):
        super().__init__()
        self.region_embed = RegionEmbedding(
            init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5])
        embed_dim = self.region_embed.embedding_dim
        self.conv_list = nn.ModuleList()
        for _ in range(n_layers):
            # Each block: pre-activation ReLU followed by two same-padded
            # convolutions; applied with a residual connection in forward()
            self.conv_list.append(nn.Sequential(
                nn.ReLU(),
                nn.Conv1d(n_filters, n_filters, kernel_size,
                          padding=kernel_size // 2),
                nn.Conv1d(n_filters, n_filters, kernel_size,
                          padding=kernel_size // 2),
            ))
        self.pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        self.embed_drop = nn.Dropout(embed_dropout)
        self.classifier = nn.Sequential(
            nn.Dropout(cls_dropout),
            nn.Linear(n_filters, num_cls),
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize every conv/linear weight (and bias) from N(0, 0.01)
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if m.bias is not None:
                    nn.init.normal_(m.bias, mean=0, std=0.01)

    def forward(self, words, seq_len=None):
        words = words.long()
        # Get region embeddings: B, n_filters, L
        x = self.region_embed(words)
        x = self.embed_drop(x)
        # No pooling before the first conv block
        x = self.conv_list[0](x) + x
        # Each later block pools (halving L), then applies a residual block
        for conv in self.conv_list[1:]:
            x = self.pool(x)
            x = conv(x) + x
        # Global max pooling over the sequence: B, C, L => B, C
        x, _ = torch.max(x, dim=2)
        x = self.classifier(x)
        return {C.OUTPUT: x}

    def predict(self, words, seq_len=None):
        x = self.forward(words, seq_len)[C.OUTPUT]
        return {C.OUTPUT: torch.argmax(x, 1)}
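

# Shape note for the pyramid above: nn.MaxPool1d(kernel_size=3, stride=2,
# padding=1) maps a length-L input to floor((L + 2 - 3) / 2) + 1, i.e.
# ceil(L / 2) timesteps. With the default n_layers=7, the six pooled blocks
# halve the sequence repeatedly, e.g. L=256 shrinks as
# 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 before the global max pool.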


class RegionEmbedding(nn.Module):
    def __init__(self, init_embed, out_dim=300, kernel_sizes=None):
        super().__init__()
        if kernel_sizes is None:
            kernel_sizes = [5, 9]
        assert isinstance(
            kernel_sizes, list), 'kernel_sizes should be List[int]'
        self.embed = get_embeddings(init_embed)
        try:
            embed_dim = self.embed.embedding_dim
        except AttributeError:
            # fastNLP embedding modules expose embed_size instead
            embed_dim = self.embed.embed_size
        # One conv branch per kernel size, each preserving sequence length
        self.region_embeds = nn.ModuleList()
        for ksz in kernel_sizes:
            self.region_embeds.append(nn.Sequential(
                nn.Conv1d(embed_dim, embed_dim, ksz, padding=ksz // 2),
            ))
        # A 1x1 convolution per branch, projecting embed_dim to out_dim
        self.linears = nn.ModuleList([nn.Conv1d(embed_dim, out_dim, 1)
                                      for _ in range(len(kernel_sizes))])
        self.embedding_dim = embed_dim

    def forward(self, x):
        x = self.embed(x)
        x = x.transpose(1, 2)
        # B, C, L
        out = 0
        # Sum the projected features across scales, pairing each conv
        # branch with its own 1x1 projection
        for conv, fc in zip(self.region_embeds, self.linears):
            conv_i = conv(x)
            out = out + fc(conv_i)
        # B, out_dim, L
        return out
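

# Note: get_embeddings is fastNLP's helper for normalizing init_embed; as
# used here it accepts either an existing embedding module or, as in the
# __main__ check below, a (vocab_size, embed_dim) tuple from which a fresh
# nn.Embedding is created.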


if __name__ == '__main__':
    x = torch.randint(0, 10000, size=(5, 15), dtype=torch.long)
    model = DPCNN((10000, 300), 20)
    y = model(x)[C.OUTPUT]  # forward returns a dict keyed by C.OUTPUT
    print(y.size(), y.mean(1), y.std(1))
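
    # A quick check of the predict path (a minimal sketch on the same toy
    # inputs as above): predict() wraps forward() and returns argmax class
    # ids of shape (B,)
    preds = model.predict(x)[C.OUTPUT]
    print(preds.size(), preds)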