import torch
import torch.nn as nn

from fastNLP.core.const import Const
from fastNLP.models.base_model import BaseModel
from fastNLP.embeddings import BertEmbedding


class BertForNLI(BaseModel):
    """Linear classification head stacked on top of a ``BertEmbedding``.

    The embedding module is supplied by the caller and stored as ``embed``;
    a single ``nn.Linear`` layer (``classifier``) maps its output features
    to ``class_num`` scores.
    """

    def __init__(self, bert_embed: BertEmbedding, class_num=3):
        """
        :param BertEmbedding bert_embed: embedding module used to encode the
            input; its ``embedding_dim`` sets the classifier's input width.
        :param int class_num: number of output classes (defaults to 3).
        """
        super().__init__()
        self.embed = bert_embed
        # The head's input width is read from the embedding so the two
        # always agree.
        in_features = self.embed.embedding_dim
        self.classifier = nn.Linear(in_features, class_num)

    def forward(self, words):
        """Encode ``words`` and score every class.

        :param torch.Tensor words: [batch_size, seq_len] input_ids
        :return: dict mapping ``Const.OUTPUT`` to the classifier logits.
            The linear layer acts on the last dimension of whatever the
            embedding returns.
            NOTE(review): if the embedding yields one vector per token, the
            logits are per-token rather than per-sequence -- confirm against
            the embedding configuration and the callers.
        """
        encoded = self.embed(words)
        scores = self.classifier(encoded)
        return {Const.OUTPUT: scores}

    def predict(self, words):
        """Return the highest-scoring class index for ``words``.

        :param torch.Tensor words: same input as :meth:`forward`.
        :return: dict mapping ``Const.OUTPUT`` to the argmax of the logits
            taken over their last dimension.
        """
        outputs = self.forward(words)
        scores = outputs[Const.OUTPUT]
        best = scores.argmax(dim=-1)
        return {Const.OUTPUT: best}