
lstm_self_attention.py 1.4 kB

import torch
import torch.nn as nn

from fastNLP.core.const import Const as C
from fastNLP.modules.encoder.lstm import LSTM
from fastNLP.modules import encoder
from fastNLP.modules.aggregator.attention import SelfAttention
from fastNLP.modules.decoder.mlp import MLP


class BiLSTM_SELF_ATTENTION(nn.Module):
    """Text classifier: BiLSTM encoder + multi-hop self-attention + MLP head."""

    def __init__(self, init_embed,
                 num_classes,
                 hidden_dim=256,
                 num_layers=1,
                 attention_unit=256,
                 attention_hops=1,
                 nfc=128):
        super(BiLSTM_SELF_ATTENTION, self).__init__()
        self.embed = encoder.Embedding(init_embed)
        # Bidirectional LSTM, so each token's output has size 2 * hidden_dim.
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_dim,
                         num_layers=num_layers, bidirectional=True)
        # Multi-hop self-attention over the LSTM outputs.
        self.attention = SelfAttention(input_size=hidden_dim * 2,
                                       attention_unit=attention_unit,
                                       attention_hops=attention_hops)
        # MLP maps the flattened attention output (hops * 2 * hidden_dim) to class logits.
        self.mlp = MLP(size_layer=[hidden_dim * 2 * attention_hops, nfc, num_classes])

    def forward(self, words):
        x_emb = self.embed(words)            # [batch, seq_len, embed_dim]
        output, _ = self.lstm(x_emb)         # [batch, seq_len, 2 * hidden_dim]
        after_attention, penalty = self.attention(output, words)
        # Flatten the attention hops into one feature vector per example.
        after_attention = after_attention.view(after_attention.size(0), -1)
        output = self.mlp(after_attention)
        return {C.OUTPUT: output}

    def predict(self, words):
        output = self(words)
        _, predict = output[C.OUTPUT].max(dim=1)
        return {C.OUTPUT: predict}
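
A minimal usage sketch, assuming a fastNLP 0.4-style API in which encoder.Embedding accepts a (vocab_size, embed_dim) tuple; the vocabulary size, embedding dimension, batch shape, and class count below are illustrative assumptions, not values taken from this repository.

# Hedged usage sketch -- all sizes below are made up for illustration.
import torch

model = BiLSTM_SELF_ATTENTION(init_embed=(10000, 300), num_classes=5)

# A fake batch of 4 sequences, each 20 token indices (index 0 treated as padding).
words = torch.randint(low=1, high=10000, size=(4, 20))

logits = model(words)[C.OUTPUT]           # shape: [4, 5]
labels = model.predict(words)[C.OUTPUT]   # shape: [4]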