
predict.py 657 B

7 years ago
```python
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """
    A two-layer perceptron for classification.

    Output: unnormalized scores (logits) over the output classes.

    Args:
        input_size: the size of the input
        hidden_size: the size of the hidden layer
        output_size: the size of the output
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(MLP, self).__init__()
        self.L1 = nn.Linear(input_size, hidden_size)  # input -> hidden
        self.L2 = nn.Linear(hidden_size, output_size)  # hidden -> output

    def forward(self, x):
        # ReLU after the first layer only; the second layer has no activation,
        # so the output stays unnormalized
        out = self.L2(F.relu(self.L1(x)))
        return out


if __name__ == "__main__":
    MLP(20, 30, 20)
```
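Because `forward` returns unnormalized scores, a caller usually applies a softmax to read them as class probabilities (in PyTorch, `torch.softmax(out, dim=-1)`). The following is a minimal, dependency-free sketch of the same computation, `L2(relu(L1(x)))`, in plain Python. The helper names (`linear`, `relu`, `mlp_forward`, `softmax`) and the toy weights are hypothetical and only illustrate the arithmetic. They are not part of this repository or of PyTorch's implementation.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(weight: Matrix, bias: Vector, x: Vector) -> Vector:
    # y = W x + b, with weight shaped (out_features, in_features) like nn.Linear
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(v: Vector) -> Vector:
    # Element-wise max(0, t)
    return [max(0.0, t) for t in v]


def mlp_forward(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    # Mirrors MLP.forward: second linear layer applied to relu(first linear layer)
    return linear(w2, b2, relu(linear(w1, b1, x)))


def softmax(scores: Vector) -> Vector:
    # Subtract the max before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    # Hypothetical toy weights: input_size=2, hidden_size=3, output_size=2
    w1 = [[1.0, -1.0], [0.5, 0.5], [-1.0, 2.0]]
    b1 = [0.0, 0.0, 0.0]
    w2 = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]
    b2 = [0.0, 0.0]
    logits = mlp_forward([1.0, 2.0], w1, b1, w2, b2)
    print(logits)  # → [3.0, -1.5]
    print(softmax(logits))  # probabilities summing to 1
```

The hidden activations here are `relu([-1.0, 1.5, 3.0]) = [0.0, 1.5, 3.0]`, which shows how the ReLU zeroes out the negative unit before the second layer.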

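A lightweight natural language processing (NLP) toolkit whose goal is to reduce the engineering code in users' projects, such as data-processing loops, training loops, and multi-GPU execution.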