You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bert_crf.py 1.1 kB

12345678910111213141516171819202122232425262728293031
  import torch.nn.functional as F
  from torch import nn

  from fastNLP.modules import ConditionalRandomField, allowed_transitions
  class BertCRF(nn.Module):
      """Sequence tagger: token embedding -> linear tag scores -> CRF.

      The embedding yields one vector per token. A linear layer maps each vector
      to one score per tag. A ``ConditionalRandomField``, restricted to the
      transitions that are legal under ``encoding_type``, then does one of two
      things:

      * scores the gold tag sequence (training), or
      * Viterbi-decodes the best tag sequence (inference).

      Args:
          embed: Token embedding module. It must expose an ``embed_size``
              attribute and is called directly on the word-id tensor.
          tag_vocab: Tag vocabulary. ``len(tag_vocab)`` is the number of tags,
              and the vocabulary is passed to ``allowed_transitions``.
          encoding_type: Tagging scheme used to restrict CRF transitions
              (default ``'bio'``).
      """

      def __init__(self, embed, tag_vocab, encoding_type='bio'):
          super().__init__()
          self.embed = embed
          num_tags = len(tag_vocab)
          # Projects each embedding vector to one unnormalized score per tag.
          self.fc = nn.Linear(self.embed.embed_size, num_tags)
          # Transitions that are illegal under the tagging scheme are forbidden,
          # including those from the start state and into the end state.
          trans = allowed_transitions(tag_vocab, encoding_type=encoding_type,
                                      include_start_end=True)
          self.crf = ConditionalRandomField(num_tags,
                                            include_start_end_trans=True,
                                            allowed_transitions=trans)

      def _forward(self, words, target):
          """Shared path for training and inference.

          Args:
              words: Tensor of token ids. Presumably shape (batch, seq_len);
                  TODO confirm. Positions whose id is 0 are masked out. This
                  assumes 0 is the padding index; confirm against the vocabulary.
              target: Gold tag ids aligned with ``words``, or ``None`` to decode.

          Returns:
              When ``target`` is given, ``{'loss': ...}`` holding whatever
              ``self.crf(logits, target, mask)`` returns. Otherwise,
              ``{'pred': paths}`` holding the tag paths from
              ``self.crf.viterbi_decode``; the decode scores are discarded.
          """
          mask = words.ne(0)
          feats = self.fc(self.embed(words))
          # The CRF receives per-token log-probabilities over tags as emissions.
          logits = F.log_softmax(feats, dim=-1)
          if target is None:
              paths, _ = self.crf.viterbi_decode(logits, mask)
              return {'pred': paths}
          return {'loss': self.crf(logits, target, mask)}

      def forward(self, words, target=None):
          """Return ``{'loss': ...}`` for ``target``, or decode if it is omitted.

          ``target`` defaults to ``None`` so that ``model(words)`` decodes instead
          of raising ``TypeError``. Existing callers that pass ``target`` are
          unaffected.
          """
          return self._forward(words, target)

      def predict(self, words):
          """Viterbi-decode ``words`` and return ``{'pred': paths}``."""
          return self._forward(words, None)