model.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_text(nn.Module):
    def __init__(self, kernel_h=(3, 4, 5), kernel_num=100, embed_num=1000, embed_dim=300,
                 num_classes=2, dropout=0.5,
                 L2_constrain=3,  # not enforced in this implementation
                 pretrained_embeddings=None):
        super(CNN_text, self).__init__()
        self.embedding = nn.Embedding(embed_num, embed_dim)
        self.dropout = nn.Dropout(dropout)
        if pretrained_embeddings is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        # The network structure.
        # Conv2d: input (N, C, H, W) -> output e.g. (50, 100, 62, 1)
        # for a batch of 50 sequences of length 64 and kernel height 3.
        self.conv1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, embed_dim)) for K in kernel_h])
        self.fc1 = nn.Linear(len(kernel_h) * kernel_num, num_classes)

    def max_pooling(self, x, conv):
        # Apply a single convolution, then max-pool over the remaining length.
        x = F.relu(conv(x)).squeeze(3)             # (N, C, L), e.g. (50, 100, 62)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)  # x.size(2)=62; (50, 100, 1) -> (50, 100)
        return x

    def forward(self, x):
        x = self.embedding(x)  # (N, H, W) = (50, 64, 300)
        x = x.unsqueeze(1)     # add a channel dim: (N, C, H, W)
        # One feature map per kernel height: [(50, 100, 62), (50, 100, 61), (50, 100, 60)]
        x = [F.relu(conv(x)).squeeze(3) for conv in self.conv1]
        # Max-pool each map over its length: [(50, 100), (50, 100), (50, 100)]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        x = torch.cat(x, 1)    # (50, 300)
        x = self.dropout(x)
        x = self.fc1(x)
        return x


if __name__ == '__main__':
    # Smoke test on a tiny vocabulary.
    model = CNN_text(kernel_h=[1, 2, 3, 4], embed_num=3, embed_dim=2)
    x = torch.LongTensor([[1, 2, 1, 2, 0]])
    print(model(x))
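
For reference, a minimal training-step sketch is below. The optimizer choice, batch shapes, and random data are illustrative assumptions, not part of this file.

import torch
import torch.nn as nn

model = CNN_text(kernel_h=[3, 4, 5], embed_num=1000, embed_dim=300, num_classes=2)
optimizer = torch.optim.Adadelta(model.parameters())  # assumed optimizer choice
criterion = nn.CrossEntropyLoss()

# Hypothetical batch: 50 sequences of length 64 over a 1000-word vocabulary.
inputs = torch.randint(0, 1000, (50, 64))
labels = torch.randint(0, 2, (50,))

model.train()
optimizer.zero_grad()
logits = model(inputs)            # (50, num_classes)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()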