@@ -2,4 +2,4 @@
 from .base import Model
 from .builder import MODELS, build_model
-from .nlp import BertForSequenceClassification
+from .nlp import BertForSequenceClassification, SbertForNLI
@@ -1,2 +1,3 @@
+from .nli_model import *  # noqa F403
 from .sequence_classification_model import *  # noqa F403
 from .text_generation_model import *  # noqa F403
@@ -0,0 +1,83 @@
+from typing import Any, Dict
+
+import numpy as np
+import torch
+from sofa import SbertModel
+from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel
+from torch import nn
+
+from modelscope.utils.constant import Tasks
+from ..base import Model
+from ..builder import MODELS
+
+__all__ = ['SbertForNLI']
+
+
+class TextClassifier(SbertPreTrainedModel):
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.encoder = SbertModel(config, add_pooling_layer=True)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, input_ids=None, token_type_ids=None):
+        outputs = self.encoder(
+            input_ids,
+            token_type_ids=token_type_ids,
+            return_dict=False)  # tuple output: the pooled output sits at index 1
+        pooled_output = outputs[1]
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+        return logits
+
+
+@MODELS.register_module(
+    Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
+class SbertForNLI(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the NLI classification model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        self.model_dir = model_dir
+        # the Chinese NLI checkpoint predicts three classes
+        self.model = TextClassifier.from_pretrained(model_dir, num_labels=3)
+        self.model.eval()
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        """Return the prediction results of the model.
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+                Example:
+                    {
+                        'predictions': array([1]),  # label id, see label_mapping.json
+                        'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
+                        'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)  # raw scores
+                    }
+        """
+        input_ids = torch.tensor(input['input_ids'], dtype=torch.long)
+        token_type_ids = torch.tensor(
+            input['token_type_ids'], dtype=torch.long)
+        with torch.no_grad():
+            logits = self.model(input_ids, token_type_ids)
+        probs = logits.softmax(-1).numpy()
+        pred = logits.argmax(-1).numpy()
+        logits = logits.numpy()
+        return {'predictions': pred, 'probabilities': probs, 'logits': logits}
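For reference, the model can be exercised without a pipeline. This is a minimal sketch, assuming a locally downloaded checkpoint (the path below is a placeholder) and the NLIPreprocessor introduced later in this change:

from modelscope.models.nlp import SbertForNLI
from modelscope.preprocessors import NLIPreprocessor

# placeholder path to a downloaded nlp_structbert_nli_chinese-base checkpoint
model = SbertForNLI('path/to/nlp_structbert_nli_chinese-base')
preprocessor = NLIPreprocessor(model.model_dir)
# (A man is playing guitar., Someone is playing an instrument.)
features = preprocessor(('一个男人在弹吉他。', '有人在演奏乐器。'))
outputs = model.forward(features)
print(outputs['predictions'], outputs['probabilities'])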
@@ -1,2 +1,3 @@
+from .nli_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
@@ -0,0 +1,88 @@
+import os
+import uuid
+from typing import Any, Dict, Union
+
+import json
+import numpy as np
+
+from modelscope.models.nlp import SbertForNLI
+from modelscope.preprocessors import NLIPreprocessor
+from modelscope.utils.constant import Tasks
+from ...models import Model
+from ..base import Input, Pipeline
+from ..builder import PIPELINES
+
+__all__ = ['NLIPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base')
+class NLIPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[SbertForNLI, str],
+                 preprocessor: NLIPreprocessor = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create an NLI pipeline for prediction.
+
+        Args:
+            model (SbertForNLI or str): a model instance or a model directory path
+            preprocessor (NLIPreprocessor): a preprocessor instance
+        """
+        assert isinstance(model, (str, SbertForNLI)), \
+            'model must be a single str or SbertForNLI'
+        sc_model = model if isinstance(model,
+                                       SbertForNLI) else SbertForNLI(model)
+        if preprocessor is None:
+            preprocessor = NLIPreprocessor(
+                sc_model.model_dir,
+                first_sequence='first_sequence',
+                second_sequence='second_sequence')
+        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
+        self.label_path = os.path.join(sc_model.model_dir,
+                                       'label_mapping.json')
+        with open(self.label_path) as f:
+            self.label_mapping = json.load(f)
+        self.label_id_to_name = {
+            idx: name
+            for name, idx in self.label_mapping.items()
+        }
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        """Process the prediction results.
+
+        Args:
+            inputs (Dict[str, Any]): the model output, containing
+                'probabilities' and 'logits'
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        probs = inputs['probabilities']
+        logits = inputs['logits']
+        # sort class ids by descending probability
+        predictions = np.argsort(-probs, axis=-1)
+        preds = predictions[0]
+        b = 0  # a single sample in the batch
+        new_result = list()
+        for pred in preds:
+            new_result.append({
+                'pred': self.label_id_to_name[pred],
+                'prob': float(probs[b][pred]),
+                'logit': float(logits[b][pred])
+            })
+        return {
+            'id': inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()),
+            'output': new_result,
+            'predictions': new_result[0]['pred'],
+            'probabilities': ','.join([str(t) for t in probs[b]]),
+            'logits': ','.join([str(t) for t in logits[b]])
+        }
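For illustration, the postprocessed result for a single sentence pair has the shape below; all values are made up, and the label names depend entirely on the model's label_mapping.json:

# illustrative output of NLIPipeline.postprocess (values are placeholders)
{
    'id': 'c0ffee42-...',                # uuid4 string when the input carries no id
    'output': [                          # classes sorted by descending probability
        {'pred': 'entailment', 'prob': 0.91, 'logit': 2.31},
        {'pred': 'neutral', 'prob': 0.07, 'logit': -0.25},
        {'pred': 'contradiction', 'prob': 0.02, 'logit': -1.48}
    ],
    'predictions': 'entailment',         # top-ranked label name
    'probabilities': '0.02,0.91,0.07',   # comma-joined, in original class order
    'logits': '-1.48,2.31,-0.25'
}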
@@ -5,4 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import *  # noqa F403
-from .nlp import TextGenerationPreprocessor
+from .nlp import NLIPreprocessor, TextGenerationPreprocessor
@@ -10,7 +10,7 @@ from modelscope.utils.type_assert import type_assert
 from .base import Preprocessor
 from .builder import PREPROCESSORS
 
-__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']
+__all__ = ['Tokenize', 'SequenceClassificationPreprocessor', 'NLIPreprocessor']
 
 
 @PREPROCESSORS.register_module(Fields.nlp)
@@ -27,6 +27,77 @@ class Tokenize(Preprocessor):
         return data
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=r'nlp_structbert_nli_chinese-base')
+class NLIPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data using the vocab.txt found under `model_dir`.
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+        from sofa import SbertTokenizer
+        self.model_dir: str = model_dir
+        self.first_sequence: str = kwargs.pop('first_sequence',
+                                              'first_sequence')
+        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
+        self.sequence_length = kwargs.pop('sequence_length', 128)
+        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+
+    @type_assert(object, tuple)
+    def __call__(self, data: tuple) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (tuple): (sentence1, sentence2)
+                sentence1 (str): a sentence
+                    Example:
+                        'you are so handsome.'
+                sentence2 (str): a sentence
+                    Example:
+                        'you are so beautiful.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        sentence1, sentence2 = data
+        new_data = {
+            self.first_sequence: sentence1,
+            self.second_sequence: sentence2
+        }
+        # preprocess the data for the model input
+        rst = {
+            'id': [],
+            'input_ids': [],
+            'attention_mask': [],
+            'token_type_ids': []
+        }
+        max_seq_length = self.sequence_length
+        text_a = new_data[self.first_sequence]
+        text_b = new_data[self.second_sequence]
+        feature = self.tokenizer(
+            text_a,
+            text_b,
+            padding=False,
+            truncation=True,
+            max_length=max_seq_length)
+        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
+        rst['input_ids'].append(feature['input_ids'])
+        rst['attention_mask'].append(feature['attention_mask'])
+        rst['token_type_ids'].append(feature['token_type_ids'])
+        return rst
+
+
 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=r'bert-sentiment-analysis')
 class SequenceClassificationPreprocessor(Preprocessor):
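The preprocessor can also be used on its own; a minimal sketch, again assuming a placeholder checkpoint path:

from modelscope.preprocessors import NLIPreprocessor

preprocessor = NLIPreprocessor('path/to/nlp_structbert_nli_chinese-base')
# (The weather is nice today., It is raining today.)
features = preprocessor(('今天天气很好。', '今天在下雨。'))
print(list(features.keys()))  # ['id', 'input_ids', 'attention_mask', 'token_type_ids']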
@@ -30,6 +30,7 @@ class Tasks(object):
     image_matting = 'image-matting'
 
     # nlp tasks
+    nli = 'nli'
     sentiment_analysis = 'sentiment-analysis'
     text_classification = 'text-classification'
     relation_extraction = 'relation-extraction'
@@ -0,0 +1,12 @@
+from modelscope.models import SbertForNLI
+from modelscope.pipelines import pipeline
+from modelscope.preprocessors import NLIPreprocessor
+
+model = SbertForNLI('../nlp_structbert_nli_chinese-base')
+print(model)
+tokenizer = NLIPreprocessor(model.model_dir)
+semantic_cls = pipeline('nli', model=model, preprocessor=tokenizer)
+print(type(semantic_cls))
+# ("On the contrary, this shows that Clinton's enemies are crazy.",
+#  'Which campus is the business administration major of Sichuan Business Vocational College on?')
+print(semantic_cls(input=('相反，这表明克林顿的敌人是疯子。', '四川商务职业学院商务管理在哪个校区?')))
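For reference, the final print emits the dict produced by NLIPipeline.postprocess: 'predictions' holds the top label name from label_mapping.json, while 'probabilities' and 'logits' are comma-joined per-class strings.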