|
|
|
@@ -4,6 +4,7 @@ from ..base import Model |
|
|
|
import numpy as np |
|
|
|
import json |
|
|
|
import os |
|
|
|
import torch |
|
|
|
from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel, SbertModel |
|
|
|
|
|
|
|
|
|
|
|
@@ -33,9 +34,11 @@ class SbertTextClassfier(SbertPreTrainedModel): |
|
|
|
|
|
|
|
class SbertForSequenceClassificationBase(Model): |
|
|
|
|
|
|
|
def __init__(self, model_dir: str, *args, **kwargs): |
|
|
|
def __init__(self, model_dir: str, model_args=None, *args, **kwargs): |
|
|
|
super().__init__(model_dir, *args, **kwargs) |
|
|
|
self.model = SbertTextClassfier.from_pretrained(model_dir) |
|
|
|
if model_args is None: |
|
|
|
model_args = {} |
|
|
|
self.model = SbertTextClassfier.from_pretrained(model_dir, **model_args) |
|
|
|
self.id2label = {} |
|
|
|
self.label_path = os.path.join(self.model_dir, 'label_mapping.json') |
|
|
|
if os.path.exists(self.label_path): |
|
|
|
@@ -43,8 +46,17 @@ class SbertForSequenceClassificationBase(Model): |
|
|
|
self.label_mapping = json.load(f) |
|
|
|
self.id2label = {idx: name for name, idx in self.label_mapping.items()} |
|
|
|
|
|
|
|
def train(self):
    """Switch the wrapped classifier into training mode.

    Returns:
        Whatever ``self.model.train()`` returns.
    """
    wrapped = self.model
    return wrapped.train()
|
|
|
|
|
|
|
def eval(self):
    """Switch the wrapped classifier into evaluation mode.

    Returns:
        Whatever ``self.model.eval()`` returns.
    """
    wrapped = self.model
    return wrapped.eval()
|
|
|
|
|
|
|
def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """Run the wrapped classifier on one preprocessed batch.

    Args:
        input: Mapping that must contain ``'input_ids'`` and
            ``'token_type_ids'``; each value is converted to a
            ``torch.long`` tensor before being passed to the model.

    Returns:
        The value returned by ``self.model`` unchanged.
        NOTE(review): the annotation promises numpy arrays, but no
        conversion happens here -- confirm against the model's output.

    Raises:
        KeyError: If either required key is missing from ``input``.
    """
    # as_tensor accepts lists, ndarrays and existing tensors alike and,
    # unlike torch.tensor, does not warn/re-copy when given a tensor.
    input_ids = torch.as_tensor(input['input_ids'], dtype=torch.long)
    token_type_ids = torch.as_tensor(
        input['token_type_ids'], dtype=torch.long)
    # Invoke the module itself rather than .forward() so that any hooks
    # registered on the model are executed.
    return self.model(input_ids, token_type_ids)
|
|
|
|
|
|
|
def postprocess(self, input, **kwargs): |
|
|
|
logits = input["logits"] |
|
|
|
|