You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

awd_lstm.py 1.0 kB

12345678910111213141516171819202122232425262728293031
import torch
import torch.nn as nn
from fastNLP.core.const import Const as C
from fastNLP.modules import encoder
from fastNLP.modules.decoder.mlp import MLP

from .awdlstm_module import LSTM
  7. class AWDLSTMSentiment(nn.Module):
  8. def __init__(self, init_embed,
  9. num_classes,
  10. hidden_dim=256,
  11. num_layers=1,
  12. nfc=128,
  13. wdrop=0.5):
  14. super(AWDLSTMSentiment,self).__init__()
  15. self.embed = encoder.Embedding(init_embed)
  16. self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, bidirectional=True, wdrop=wdrop)
  17. self.mlp = MLP(size_layer=[hidden_dim* 2, nfc, num_classes])
  18. def forward(self, words):
  19. x_emb = self.embed(words)
  20. output, _ = self.lstm(x_emb)
  21. output = self.mlp(output[:,-1,:])
  22. return {C.OUTPUT: output}
  23. def predict(self, words):
  24. output = self(words)
  25. _, predict = output[C.OUTPUT].max(dim=1)
  26. return {C.OUTPUT: predict}