
model.py

import torch
from torch import nn
from fastNLP.modules.encoder.bert import BertModel


class Classifier(nn.Module):
    """Scores each sentence embedding with a linear layer followed by a sigmoid."""

    def __init__(self, hidden_size):
        super(Classifier, self).__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, mask_cls):
        h = self.linear(inputs).squeeze(-1)  # [batch_size, num_sents]
        sent_scores = self.sigmoid(h) * mask_cls.float()  # zero out padded sentences
        return sent_scores


class BertSum(nn.Module):
    """Extractive summarizer: BERT encoder plus a per-sentence sigmoid classifier."""

    def __init__(self, hidden_size=768):
        super(BertSum, self).__init__()
        self.hidden_size = hidden_size
        self.encoder = BertModel.from_pretrained('/path/to/uncased_L-12_H-768_A-12')
        self.decoder = Classifier(self.hidden_size)

    def forward(self, article, segment_id, cls_id):
        # Padding tokens have id 0; padded [CLS] slots in cls_id are marked with -1.
        input_mask = 1 - (article == 0).long()
        mask_cls = 1 - (cls_id == -1).long()
        assert input_mask.size() == article.size()
        assert mask_cls.size() == cls_id.size()

        bert_out = self.encoder(article, token_type_ids=segment_id, attention_mask=input_mask)
        bert_out = bert_out[0][-1]  # hidden states of the last encoder layer

        # Gather the hidden state at each [CLS] position as the sentence embedding.
        sent_emb = bert_out[torch.arange(bert_out.size(0)).unsqueeze(1), cls_id]
        sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
        assert sent_emb.size() == (article.size(0), cls_id.size(1), self.hidden_size)  # [batch_size, num_sents, hidden_size]

        sent_scores = self.decoder(sent_emb, mask_cls)  # [batch_size, num_sents]
        assert sent_scores.size() == (article.size(0), cls_id.size(1))
        return {'pred': sent_scores, 'mask': mask_cls}
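
The sentence-embedding lookup relies on PyTorch advanced indexing: broadcasting a [batch_size, 1] row index against the [batch_size, num_sents] cls_id matrix picks out one hidden vector per [CLS] position. A minimal sketch with toy shapes (the values below are illustrative, not from the repository):

import torch

# Toy "encoder output": batch of 2 sequences, 5 tokens each, hidden size 4.
bert_out = torch.arange(2 * 5 * 4, dtype=torch.float).reshape(2, 5, 4)

# [CLS] token positions per example; -1 pads the shorter example.
cls_id = torch.tensor([[0, 3], [0, -1]])

# Advanced indexing: rows [[0], [1]] broadcast against cls_id,
# selecting bert_out[b, cls_id[b, i]] for every (b, i).
rows = torch.arange(bert_out.size(0)).unsqueeze(1)  # shape [2, 1]
sent_emb = bert_out[rows, cls_id]                   # shape [2, 2, 4]

# A padded position (cls_id == -1) indexes the last token via negative
# indexing; multiplying by the mask zeroes that embedding out afterwards.
mask_cls = 1 - (cls_id == -1).long()
sent_emb = sent_emb * mask_cls.unsqueeze(-1).float()
print(sent_emb.shape)  # torch.Size([2, 2, 4])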
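
For reference, a forward pass on a dummy batch might look like the following. This is a hedged sketch: it assumes the pretrained weights referenced in BertSum.__init__ are actually on disk, and the token ids and [CLS] positions here are made up for shape-checking only.

import torch

model = BertSum(hidden_size=768)  # loads BERT weights from the path above

batch_size, seq_len = 2, 128
article = torch.randint(1, 30000, (batch_size, seq_len))  # token ids, 0 = padding
article[:, -10:] = 0                                      # simulate trailing padding
segment_id = torch.zeros(batch_size, seq_len, dtype=torch.long)
cls_id = torch.tensor([[0, 30, 60, 90],
                       [0, 40, 80, -1]])                  # -1 pads a missing sentence

out = model(article, segment_id, cls_id)
print(out['pred'].shape)  # torch.Size([2, 4]) -- one score per sentence
print(out['mask'].shape)  # torch.Size([2, 4])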