| @@ -1,2 +1,3 @@ | |||
| from .sequence_classification_model import * # noqa F403 | |||
| from .text_generation_model import * # noqa F403 | |||
| from .zero_shot_classification_model import * | |||
| @@ -0,0 +1,45 @@ | |||
| from typing import Any, Dict | |||
| import torch | |||
| import numpy as np | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Model | |||
| from ..builder import MODELS | |||
| __all__ = ['BertForZeroShotClassification'] | |||
| @MODELS.register_module( | |||
| Tasks.zero_shot_classification, module_name=r'bert-zero-shot-classification') | |||
| class BertForZeroShotClassification(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the zero shot classification model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from sofa import SbertForSequenceClassification | |||
| self.model = SbertForSequenceClassification.from_pretrained(model_dir) | |||
| self.model.eval() | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| with torch.no_grad(): | |||
| outputs = self.model(**input) | |||
| logits = outputs["logits"].numpy() | |||
| res = {'logits': logits} | |||
| return res | |||
| @@ -20,6 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), | |||
| Tasks.text_classification: | |||
| ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | |||
| Tasks.zero_shot_classification: | |||
| ('bert-zero-shot-classification', 'damo/nlp_structbert_zero-shot-classification_chinese-base'), | |||
| Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | |||
| Tasks.image_captioning: ('ofa', None), | |||
| Tasks.image_generation: | |||
| @@ -1,2 +1,3 @@ | |||
| from .sequence_classification_pipeline import * # noqa F403 | |||
| from .text_generation_pipeline import * # noqa F403 | |||
| from .zero_shot_classification_pipeline import * | |||
| @@ -0,0 +1,78 @@ | |||
| import os | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| import json | |||
| import numpy as np | |||
| from modelscope.models.nlp import BertForZeroShotClassification | |||
| from modelscope.preprocessors import ZeroShotClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from ...models import Model | |||
| from ..base import Input, Pipeline | |||
| from ..builder import PIPELINES | |||
| from scipy.special import softmax | |||
| __all__ = ['ZeroShotClassificationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.zero_shot_classification, | |||
| module_name=r'bert-zero-shot-classification') | |||
| class ZeroShotClassificationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[BertForZeroShotClassification, str], | |||
| preprocessor: ZeroShotClassificationPreprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SbertForSentimentClassification): a model instance | |||
| preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, BertForZeroShotClassification), \ | |||
| 'model must be a single str or BertForZeroShotClassification' | |||
| sc_model = model if isinstance( | |||
| model, | |||
| BertForZeroShotClassification) else Model.from_pretrained(model) | |||
| self.entailment_id = 0 | |||
| self.contradiction_id = 2 | |||
| self.candidate_labels = kwargs.pop("candidate_labels") | |||
| self.hypothesis_template = kwargs.pop('hypothesis_template', "{}") | |||
| self.multi_label = kwargs.pop('multi_label', False) | |||
| if preprocessor is None: | |||
| preprocessor = ZeroShotClassificationPreprocessor( | |||
| sc_model.model_dir, | |||
| candidate_labels=self.candidate_labels, | |||
| hypothesis_template=self.hypothesis_template | |||
| ) | |||
| super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs) | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, Any]: the prediction results | |||
| """ | |||
| logits = inputs['logits'] | |||
| if self.multi_label or len(self.candidate_labels) == 1: | |||
| logits = logits[..., [self.contradiction_id, self.entailment_id]] | |||
| scores = softmax(logits, axis=-1)[..., 1] | |||
| else: | |||
| logits = logits[..., self.entailment_id] | |||
| scores = softmax(logits, axis=-1) | |||
| reversed_index = list(reversed(scores.argsort())) | |||
| result = { | |||
| "labels": [self.candidate_labels[i] for i in reversed_index], | |||
| "scores": [scores[i].item() for i in reversed_index], | |||
| } | |||
| return result | |||
| @@ -5,4 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor | |||
| from .common import Compose | |||
| from .image import LoadImage, load_image | |||
| from .nlp import * # noqa F403 | |||
| from .nlp import TextGenerationPreprocessor | |||
| from .nlp import TextGenerationPreprocessor, ZeroShotClassificationPreprocessor | |||
| @@ -147,3 +147,49 @@ class TextGenerationPreprocessor(Preprocessor): | |||
| rst['token_type_ids'].append(feature['token_type_ids']) | |||
| return {k: torch.tensor(v) for k, v in rst.items()} | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=r'bert-zero-shot-classification') | |||
| class ZeroShotClassificationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from sofa import SbertTokenizer | |||
| self.model_dir: str = model_dir | |||
| self.sequence_length = kwargs.pop('sequence_length', 512) | |||
| self.candidate_labels = kwargs.pop("candidate_labels") | |||
| self.hypothesis_template = kwargs.pop('hypothesis_template', "{}") | |||
| self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| pairs = [[data, self.hypothesis_template.format(label)] for label in self.candidate_labels] | |||
| features = self.tokenizer( | |||
| pairs, | |||
| padding=True, | |||
| truncation=True, | |||
| max_length=self.sequence_length, | |||
| return_tensors='pt', | |||
| truncation_strategy='only_first' | |||
| ) | |||
| return features | |||
| @@ -30,6 +30,7 @@ class Tasks(object): | |||
| image_matting = 'image-matting' | |||
| # nlp tasks | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| text_classification = 'text-classification' | |||
| relation_extraction = 'relation-extraction' | |||
| @@ -0,0 +1,68 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import BertForZeroShotClassification | |||
| from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline | |||
| from modelscope.preprocessors import ZeroShotClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| class ZeroShotClassificationTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base' | |||
| sentence = '全新突破 解放军运20版空中加油机曝光' | |||
| candidate_labels = ["文化", "体育", "娱乐", "财经", "家居", "汽车", "教育", "科技", "军事"] | |||
| def test_run_from_local(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = ZeroShotClassificationPreprocessor(cache_path, candidate_labels=self.candidate_labels) | |||
| model = BertForZeroShotClassification( | |||
| cache_path, tokenizer=tokenizer) | |||
| pipeline1 = ZeroShotClassificationPipeline( | |||
| model, | |||
| preprocessor=tokenizer, | |||
| candidate_labels=self.candidate_labels, | |||
| ) | |||
| pipeline2 = pipeline( | |||
| Tasks.zero_shot_classification, | |||
| model=model, | |||
| preprocessor=tokenizer, | |||
| candidate_labels=self.candidate_labels | |||
| ) | |||
| print(f'sentence: {self.sentence}\n' | |||
| f'pipeline1:{pipeline1(input=self.sentence)}') | |||
| print() | |||
| print(f'sentence: {self.sentence}\n' | |||
| f'pipeline2: {pipeline2(input=self.sentence)}') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = ZeroShotClassificationPreprocessor(model.model_dir, candidate_labels=self.candidate_labels) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.zero_shot_classification, | |||
| model=model, | |||
| preprocessor=tokenizer, | |||
| candidate_labels=self.candidate_labels | |||
| ) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.zero_shot_classification, | |||
| model=self.model_id, | |||
| candidate_labels=self.candidate_labels | |||
| ) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.zero_shot_classification, | |||
| candidate_labels=self.candidate_labels) | |||
| print(pipeline_ins(input=self.sentence)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||