You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

bert.py 801 B

123456789101112131415161718192021222324252627282930
  1. import torch
  2. import torch.nn as nn
  3. from fastNLP.core.const import Const
  4. from fastNLP.models.base_model import BaseModel
  5. from fastNLP.embeddings import BertEmbedding
  6. class BertForNLI(BaseModel):
  7. def __init__(self, bert_embed: BertEmbedding, class_num=3):
  8. super(BertForNLI, self).__init__()
  9. self.embed = bert_embed
  10. self.classifier = nn.Linear(self.embed.embedding_dim, class_num)
  11. def forward(self, words):
  12. """
  13. :param torch.Tensor words: [batch_size, seq_len] input_ids
  14. :return:
  15. """
  16. hidden = self.embed(words)
  17. logits = self.classifier(hidden)
  18. return {Const.OUTPUT: logits}
  19. def predict(self, words):
  20. logits = self.forward(words)[Const.OUTPUT]
  21. return {Const.OUTPUT: logits.argmax(dim=-1)}