diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py index 170e525e..2d852970 100644 --- a/modelscope/models/__init__.py +++ b/modelscope/models/__init__.py @@ -2,4 +2,4 @@ from .base import Model from .builder import MODELS, build_model -from .nlp import BertForSequenceClassification +from .nlp import BertForSequenceClassification, SbertForNLI diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py index b2a1d43b..114295fc 100644 --- a/modelscope/models/nlp/__init__.py +++ b/modelscope/models/nlp/__init__.py @@ -1,2 +1,3 @@ +from .nli_model import * # noqa F403 from .sequence_classification_model import * # noqa F403 from .text_generation_model import * # noqa F403 diff --git a/modelscope/models/nlp/nli_model.py b/modelscope/models/nlp/nli_model.py new file mode 100644 index 00000000..05166bd0 --- /dev/null +++ b/modelscope/models/nlp/nli_model.py @@ -0,0 +1,83 @@ +import os +from typing import Any, Dict + +import numpy as np +import torch +from sofa import SbertConfig, SbertModel +from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel +from torch import nn +from transformers.activations import ACT2FN, get_activation +from transformers.models.bert.modeling_bert import SequenceClassifierOutput + +from modelscope.utils.constant import Tasks +from ..base import Model, Tensor +from ..builder import MODELS + +__all__ = ['SbertForNLI'] + + +class TextClassifier(SbertPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.encoder = SbertModel(config, add_pooling_layer=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, input_ids=None, token_type_ids=None): + outputs = self.encoder( + input_ids, + token_type_ids=token_type_ids, + return_dict=None, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + + +@MODELS.register_module( + Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base') +class SbertForNLI(Model): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + model_cls (Optional[Any], optional): model loader, if None, use the + default loader to load model weights, by default None. + """ + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + + self.model = TextClassifier.from_pretrained(model_dir, num_labels=3) + self.model.eval() + + def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: + """return the result by the model + + Args: + input (Dict[str, Any]): the preprocessed data + + Returns: + Dict[str, np.ndarray]: results + Example: + { + 'predictions': array([1]), # lable 0-negative 1-positive + 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), + 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value + } + """ + input_ids = torch.tensor(input['input_ids'], dtype=torch.long) + token_type_ids = torch.tensor( + input['token_type_ids'], dtype=torch.long) + with torch.no_grad(): + logits = self.model(input_ids, token_type_ids) + probs = logits.softmax(-1).numpy() + pred = logits.argmax(-1).numpy() + logits = logits.numpy() + res = {'predictions': pred, 'probabilities': probs, 'logits': logits} + return res diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py index 3dbbc1bb..e9c5ab98 100644 --- a/modelscope/pipelines/nlp/__init__.py +++ b/modelscope/pipelines/nlp/__init__.py @@ -1,2 +1,3 @@ +from .nli_pipeline import * # noqa F403 from .sequence_classification_pipeline import * # noqa F403 from .text_generation_pipeline import * # noqa F403 diff --git a/modelscope/pipelines/nlp/nli_pipeline.py b/modelscope/pipelines/nlp/nli_pipeline.py new file mode 100644 index 00000000..fe658c77 --- /dev/null +++ b/modelscope/pipelines/nlp/nli_pipeline.py @@ -0,0 +1,88 @@ +import os +import uuid +from typing import Any, Dict, Union + +import json +import numpy as np + +from modelscope.models.nlp import SbertForNLI +from modelscope.preprocessors import NLIPreprocessor +from modelscope.utils.constant import Tasks +from ...models import Model +from ..base import Input, Pipeline +from ..builder import PIPELINES + +__all__ = ['NLIPipeline'] + + +@PIPELINES.register_module( + Tasks.nli, module_name=r'nlp_structbert_nli_chinese-base') +class NLIPipeline(Pipeline): + + def __init__(self, + model: Union[SbertForNLI, str], + preprocessor: NLIPreprocessor = None, + **kwargs): + """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction + + Args: + model (SbertForNLI): a model instance + preprocessor (NLIPreprocessor): a preprocessor instance + """ + assert isinstance(model, str) or isinstance(model, SbertForNLI), \ + 'model must be a single str or SbertForNLI' + sc_model = model if isinstance(model, + SbertForNLI) else SbertForNLI(model) + if preprocessor is None: + preprocessor = NLIPreprocessor( + sc_model.model_dir, + first_sequence='first_sequence', + second_sequence='second_sequence') + super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) + + self.label_path = os.path.join(sc_model.model_dir, + 'label_mapping.json') + with open(self.label_path) as f: + self.label_mapping = json.load(f) + self.label_id_to_name = { + idx: name + for name, idx in self.label_mapping.items() + } + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + + probs = inputs['probabilities'] + logits = inputs['logits'] + predictions = np.argsort(-probs, axis=-1) + preds = predictions[0] + b = 0 + new_result = list() + for pred in preds: + new_result.append({ + 'pred': self.label_id_to_name[pred], + 'prob': float(probs[b][pred]), + 'logit': float(logits[b][pred]) + }) + new_results = list() + new_results.append({ + 'id': + inputs['id'][b] if 'id' in inputs else str(uuid.uuid4()), + 'output': + new_result, + 'predictions': + new_result[0]['pred'], + 'probabilities': + ','.join([str(t) for t in inputs['probabilities'][b]]), + 'logits': + ','.join([str(t) for t in inputs['logits'][b]]) + }) + + return new_results[0] diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 518ea977..47a713ff 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -5,4 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor from .common import Compose from .image import LoadImage, load_image from .nlp import * # noqa F403 -from .nlp import TextGenerationPreprocessor +from .nlp import NLIPreprocessor, TextGenerationPreprocessor diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 0de41bfc..37442dcb 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -10,7 +10,7 @@ from modelscope.utils.type_assert import type_assert from .base import Preprocessor from .builder import PREPROCESSORS -__all__ = ['Tokenize', 'SequenceClassificationPreprocessor'] +__all__ = ['Tokenize', 'SequenceClassificationPreprocessor', 'NLIPreprocessor'] @PREPROCESSORS.register_module(Fields.nlp) @@ -27,6 +27,77 @@ class Tokenize(Preprocessor): return data +@PREPROCESSORS.register_module( + Fields.nlp, module_name=r'nlp_structbert_nli_chinese-base') +class NLIPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, *args, **kwargs): + """preprocess the data via the vocab.txt from the `model_dir` path + + Args: + model_dir (str): model path + """ + + super().__init__(*args, **kwargs) + + from sofa import SbertTokenizer + self.model_dir: str = model_dir + self.first_sequence: str = kwargs.pop('first_sequence', + 'first_sequence') + self.second_sequence = kwargs.pop('second_sequence', 'second_sequence') + self.sequence_length = kwargs.pop('sequence_length', 128) + + self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) + + @type_assert(object, tuple) + def __call__(self, data: tuple) -> Dict[str, Any]: + """process the raw input data + + Args: + data (tuple): [sentence1, sentence2] + sentence1 (str): a sentence + Example: + 'you are so handsome.' + sentence2 (str): a sentence + Example: + 'you are so beautiful.' + + Returns: + Dict[str, Any]: the preprocessed data + """ + sentence1, sentence2 = data + new_data = { + self.first_sequence: sentence1, + self.second_sequence: sentence2 + } + # preprocess the data for the model input + + rst = { + 'id': [], + 'input_ids': [], + 'attention_mask': [], + 'token_type_ids': [] + } + + max_seq_length = self.sequence_length + + text_a = new_data[self.first_sequence] + text_b = new_data[self.second_sequence] + feature = self.tokenizer( + text_a, + text_b, + padding=False, + truncation=True, + max_length=max_seq_length) + + rst['id'].append(new_data.get('id', str(uuid.uuid4()))) + rst['input_ids'].append(feature['input_ids']) + rst['attention_mask'].append(feature['attention_mask']) + rst['token_type_ids'].append(feature['token_type_ids']) + + return rst + + @PREPROCESSORS.register_module( Fields.nlp, module_name=r'bert-sentiment-analysis') class SequenceClassificationPreprocessor(Preprocessor): diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index c51e2445..a955574b 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -30,6 +30,7 @@ class Tasks(object): image_matting = 'image-matting' # nlp tasks + nli = 'nli' sentiment_analysis = 'sentiment-analysis' text_classification = 'text-classification' relation_extraction = 'relation-extraction' diff --git a/test.py b/test.py new file mode 100644 index 00000000..d0cd093b --- /dev/null +++ b/test.py @@ -0,0 +1,12 @@ +from modelscope.models import SbertForNLI +from modelscope.pipelines import pipeline +from modelscope.preprocessors import NLIPreprocessor + +model = SbertForNLI('../nlp_structbert_nli_chinese-base') +print(model) +tokenizer = NLIPreprocessor(model.model_dir) + +semantic_cls = pipeline('nli', model=model, preprocessor=tokenizer) +print(type(semantic_cls)) + +print(semantic_cls(input=('相反,这表明克林顿的敌人是疯子。', '四川商务职业学院商务管理在哪个校区?')))