
BertTC.py

from fastNLP.embeddings import BertEmbedding
import torch
import torch.nn as nn
from fastNLP.core.const import Const as C


class BertTC(nn.Module):
    def __init__(self, vocab, num_class, bert_model_dir_or_name, fine_tune=False):
        super(BertTC, self).__init__()
        # include_cls_sep=True makes BertEmbedding prepend [CLS] and append [SEP],
        # so position 0 of the output sequence is the [CLS] vector.
        self.embed = BertEmbedding(vocab, requires_grad=fine_tune,
                                   model_dir_or_name=bert_model_dir_or_name,
                                   include_cls_sep=True)
        self.classifier = nn.Linear(self.embed.embedding_dim, num_class)

    def forward(self, words):
        # Use the [CLS] embedding as the sentence representation for classification.
        embedding_cls = self.embed(words)[:, 0]
        output = self.classifier(embedding_cls)
        return {C.OUTPUT: output}

    def predict(self, words):
        return self.forward(words)


if __name__ == "__main__":
    # Sanity check of the [:, 0] slicing used in forward(): it takes column 0 of each row.
    ta = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    tb = ta[:, 0]
    print(tb)  # tensor([1, 4, 7])
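
For context, a minimal usage sketch follows. The toy vocabulary, the sample sentence, and the "cn" model alias (assumed here to be a name fastNLP resolves to a pretrained Chinese BERT) are illustrative assumptions, not part of BertTC.py; a local BERT checkpoint directory passed as bert_model_dir_or_name would work the same way.

from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst(list("这部电影很好看"))  # toy vocabulary, assumed for this demo

# "cn" is assumed to be a fastNLP alias for a pretrained Chinese BERT;
# substitute a local model directory if preferred.
model = BertTC(vocab, num_class=2, bert_model_dir_or_name="cn")

# Batch of one sentence, encoded as word indices from the vocabulary.
words = torch.LongTensor([[vocab.to_index(w) for w in "这部电影很好看"]])
logits = model.predict(words)[C.OUTPUT]
print(logits.shape)  # torch.Size([1, 2]): one sentence, num_class=2 logits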