@@ -17,5 +17,5 @@ class SbertForNLI(SbertForSequenceClassificationBase):
             model_cls (Optional[Any], optional): model loader, if None, use the
                 default loader to load model weights, by default None.
         """
-        super().__init__(model_dir, *args, **kwargs)
+        super().__init__(model_dir, *args, model_args={"num_labels": 3}, **kwargs)
         assert self.model.config.num_labels == 3
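Note: below is a minimal, self-contained sketch of the pattern this hunk relies on, i.e. a subclass forcing a fixed label count by forwarding model_args to the base constructor. The stub classes only mimic SbertForSequenceClassificationBase / SbertTextClassfier, and from_pretrained accepting a num_labels kwarg is an assumption based on the surrounding hunks, not the real sofa/modelscope implementation.

    class _StubConfig:
        def __init__(self, num_labels=2):
            self.num_labels = num_labels

    class _StubClassifier:
        # Stand-in for SbertTextClassfier: from_pretrained forwards extra
        # kwargs (here num_labels) into the model config.
        def __init__(self, config):
            self.config = config

        @classmethod
        def from_pretrained(cls, model_dir, **model_args):
            return cls(_StubConfig(**model_args))

    class _StubBase:
        # Mirrors SbertForSequenceClassificationBase.__init__ after this change.
        def __init__(self, model_dir, model_args=None):
            if model_args is None:
                model_args = {}
            self.model = _StubClassifier.from_pretrained(model_dir, **model_args)

    class _StubNLI(_StubBase):
        # Mirrors SbertForNLI: NLI is a fixed 3-way task, so num_labels is pinned.
        def __init__(self, model_dir):
            super().__init__(model_dir, model_args={"num_labels": 3})
            assert self.model.config.num_labels == 3

    _StubNLI("some/model_dir")  # the override reaches the loader; assert passes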
@@ -1,13 +1,14 @@
 from modelscope.utils.constant import Tasks
 from .sbert_for_sequence_classification import SbertForSequenceClassificationBase
 from ..builder import MODELS
+from modelscope.metainfo import Models

 __all__ = ['SbertForSentimentClassification']


 @MODELS.register_module(
     Tasks.sentiment_classification,
-    module_name=r'sbert-sentiment-classification')
+    module_name=Models.structbert)
 class SbertForSentimentClassification(SbertForSequenceClassificationBase):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -4,6 +4,7 @@ from ..base import Model
 import numpy as np
 import json
 import os
+import torch

 from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel, SbertModel
@@ -33,9 +34,11 @@ class SbertTextClassfier(SbertPreTrainedModel):

 class SbertForSequenceClassificationBase(Model):

-    def __init__(self, model_dir: str, *args, **kwargs):
+    def __init__(self, model_dir: str, model_args=None, *args, **kwargs):
         super().__init__(model_dir, *args, **kwargs)
-        self.model = SbertTextClassfier.from_pretrained(model_dir)
+        if model_args is None:
+            model_args = {}
+        self.model = SbertTextClassfier.from_pretrained(model_dir, **model_args)
         self.id2label = {}
         self.label_path = os.path.join(self.model_dir, 'label_mapping.json')
         if os.path.exists(self.label_path):
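For reference, a small sketch of the label lookup that __init__ builds from label_mapping.json: the file maps label names to ids, and id2label is the inverted dict used by the pipelines below. The file content shown here is purely illustrative.

    import json

    # Hypothetical label_mapping.json content (illustrative only):
    raw = '{"negative": 0, "positive": 1}'
    label_mapping = json.loads(raw)
    id2label = {idx: name for name, idx in label_mapping.items()}
    print(id2label)  # {0: 'negative', 1: 'positive'}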
@@ -43,8 +46,17 @@ class SbertForSequenceClassificationBase(Model):
                 self.label_mapping = json.load(f)
             self.id2label = {idx: name for name, idx in self.label_mapping.items()}

+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()

     def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
-        return self.model.forward(input)
+        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
+        token_type_ids = torch.tensor(
+            input['token_type_ids'], dtype=torch.long)
+        return self.model.forward(input_ids, token_type_ids)

     def postprocess(self, input, **kwargs):
         logits = input["logits"]
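A hedged sketch of what the reworked forward now expects: the preprocessor is assumed to hand over plain Python lists (or numpy arrays) under 'input_ids' and 'token_type_ids', which forward converts to long tensors before calling the underlying classifier. The token ids below are made up; note that no attention_mask is passed in this version.

    import torch

    # Made-up tokenizer output; real ids come from the SBERT preprocessor.
    inputs = {
        'input_ids': [[101, 2023, 2003, 1037, 3231, 102]],
        'token_type_ids': [[0, 0, 0, 0, 0, 0]],
    }

    input_ids = torch.tensor(inputs['input_ids'], dtype=torch.long)
    token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long)
    print(input_ids.shape)  # torch.Size([1, 6])
    # self.model.forward(input_ids, token_type_ids) is then expected to return
    # a dict containing 'logits', which postprocess consumes.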
@@ -26,6 +26,12 @@ class SbertForZeroShotClassification(Model):
         from sofa import SbertForSequenceClassification
         self.model = SbertForSequenceClassification.from_pretrained(model_dir)
+
+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()

     def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
         """return the result by the model
@@ -69,7 +69,7 @@ class NLIPipeline(Pipeline):
         new_result = list()
         for pred in preds:
             new_result.append({
-                'pred': self.label_id_to_name[pred],
+                'pred': self.model.id2label[pred],
                 'prob': float(probs[b][pred]),
                 'logit': float(logits[b][pred])
             })
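For clarity, a rough, self-contained reconstruction of the postprocess step these pipeline hunks touch: softmax the logits, then report the label name (now looked up via the model's id2label), probability and raw logit for each prediction. The helper name and the single-argmax simplification are assumptions; the actual pipelines may rank several candidate labels per input.

    import numpy as np

    def postprocess_sketch(logits, id2label):
        # Row-wise softmax, then one top prediction per batch item.
        probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
        results = []
        for b, row in enumerate(probs):
            pred = int(row.argmax())
            results.append({
                'pred': id2label[pred],
                'prob': float(probs[b][pred]),
                'logit': float(logits[b][pred]),
            })
        return results

    print(postprocess_sketch(
        np.array([[2.1, 0.3, -1.0]]),
        {0: 'entailment', 1: 'neutral', 2: 'contradiction'}))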
@@ -70,7 +70,7 @@ class SentimentClassificationPipeline(Pipeline):
         new_result = list()
         for pred in preds:
             new_result.append({
-                'pred': self.label_id_to_name[pred],
+                'pred': self.model.id2label[pred],
                 'prob': float(probs[b][pred]),
                 'logit': float(logits[b][pred])
             })
@@ -44,7 +44,7 @@ class ZeroShotClassificationPipeline(Pipeline):
         if preprocessor is None:
             preprocessor = ZeroShotClassificationPreprocessor(
                 sc_model.model_dir)
-        model.eval()
+        sc_model.eval()
         super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

     def _sanitize_parameters(self, **kwargs):
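Finally, a small stand-in showing why the wrapper-level eval()/train() delegation added earlier matters here: the pipeline constructor receives sc_model (the instantiated wrapper), while the original model argument may be a plain model id string, so calling eval() on the wrapper is what actually flips the underlying torch module into inference mode. The classes below are illustrative stand-ins, not the modelscope API.

    import torch.nn as nn

    class WrapperSketch:
        # Stand-in for the Model wrappers in this diff.
        def __init__(self):
            self.model = nn.Linear(4, 3)  # placeholder for the SBERT classifier

        def train(self):
            return self.model.train()

        def eval(self):
            return self.model.eval()

    sc_model = WrapperSketch()
    sc_model.eval()
    print(sc_model.model.training)  # False: dropout etc. disabled for inference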