You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

model.py 1.4 kB

1234567891011121314151617181920212223242526272829303132333435
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class CNN_text(nn.Module):
  5. def __init__(self, kernel_h=[3, 4, 5], kernel_num=100, embed_num=1000, embed_dim=300, dropout=0.5, L2_constrain=3,
  6. batchsize=50, pretrained_embeddings=None):
  7. super(CNN_text, self).__init__()
  8. self.embedding = nn.Embedding(embed_num, embed_dim)
  9. self.dropout = nn.Dropout(dropout)
  10. if pretrained_embeddings is not None:
  11. self.embedding.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
  12. # the network structure
  13. # Conv2d: input- N,C,H,W output- (50,100,62,1)
  14. self.conv1 = nn.ModuleList([nn.Conv2d(1, 100, (K, 300)) for K in kernel_h])
  15. self.fc1 = nn.Linear(300, 2)
  16. def max_pooling(self, x):
  17. x = F.relu(conv(x)).squeeze(3) # N,C,L - (50,100,62)
  18. x = F.max_pool1d(x, x.size(2)).squeeze(2)
  19. # x.size(2)=62 squeeze: (50,100,1) -> (50,100)
  20. return x
  21. def forward(self, x):
  22. x = self.embedding(x) # output: (N,H,W) = (50,64,300)
  23. x = x.unsqueeze(1) # (N,C,H,W)
  24. x = [F.relu(conv(x)).squeeze(3) for conv in self.conv1] # [N, C, H(50,100,62),(50,100,61),(50,100,60)]
  25. x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [N,C(50,100),(50,100),(50,100)]
  26. x = torch.cat(x, 1)
  27. x = self.dropout(x)
  28. x = self.fc1(x)
  29. return x

一款轻量级的自然语言处理(NLP)工具包,目标是减少用户项目中的工程型代码,例如数据处理循环、训练循环、多卡运行等