From 34db19131f68fa3cf284ad60fa7311523788474c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=98=93=E7=9B=B8?=
Date: Thu, 16 Jun 2022 14:52:03 +0800
Subject: [PATCH] init

---
 modelscope/models/nlp/__init__.py             |  1 +
 .../nlp/zero_shot_classification_model.py     | 46 ++++++++++
 modelscope/pipelines/builder.py               |  2 +
 modelscope/pipelines/nlp/__init__.py          |  1 +
 .../nlp/zero_shot_classification_pipeline.py  | 87 +++++++++++++++++++
 modelscope/preprocessors/__init__.py          |  2 +-
 modelscope/preprocessors/nlp.py               | 48 ++++++++++
 modelscope/utils/constant.py                  |  1 +
 .../test_zero_shot_classification.py          | 68 +++++++++++++++
 9 files changed, 255 insertions(+), 1 deletion(-)
 create mode 100644 modelscope/models/nlp/zero_shot_classification_model.py
 create mode 100644 modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
 create mode 100644 tests/pipelines/test_zero_shot_classification.py

diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index b2a1d43b..37e6dd3c 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -1,2 +1,3 @@
 from .sequence_classification_model import *  # noqa F403
 from .text_generation_model import *  # noqa F403
+from .zero_shot_classification_model import *  # noqa F403
diff --git a/modelscope/models/nlp/zero_shot_classification_model.py b/modelscope/models/nlp/zero_shot_classification_model.py
new file mode 100644
index 00000000..3f658dba
--- /dev/null
+++ b/modelscope/models/nlp/zero_shot_classification_model.py
@@ -0,0 +1,46 @@
+from typing import Any, Dict
+import torch
+import numpy as np
+
+from modelscope.utils.constant import Tasks
+from ..base import Model
+from ..builder import MODELS
+
+__all__ = ['BertForZeroShotClassification']
+
+
+@MODELS.register_module(
+    Tasks.zero_shot_classification, module_name=r'bert-zero-shot-classification')
+class BertForZeroShotClassification(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the zero shot classification model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+
+        super().__init__(model_dir, *args, **kwargs)
+        from sofa import SbertForSequenceClassification
+        self.model = SbertForSequenceClassification.from_pretrained(model_dir)
+        self.model.eval()
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        """return the result by the model
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+                Example:
+                    {
+                        'logits': array([[-0.53860897,  1.5029076 ]], dtype=float32) # true value
+                    }
+        """
+        with torch.no_grad():
+            outputs = self.model(**input)
+        # Tensor.numpy() only works on CPU tensors, so copy the logits to host first.
+        logits = outputs["logits"].detach().cpu().numpy()
+        res = {'logits': logits}
+        return res
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 6495a5db..5c5190c0 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -20,6 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.image_matting: ('image-matting', 'damo/image-matting-person'),
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
+    Tasks.zero_shot_classification:
+    ('bert-zero-shot-classification', 'damo/nlp_structbert_zero-shot-classification_chinese-base'),
     Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
     Tasks.image_captioning: ('ofa', None),
     Tasks.image_generation:
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index 3dbbc1bb..02f4fbfa 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -1,2 +1,3 @@
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
+from .zero_shot_classification_pipeline import *  # noqa F403
diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
new file mode 100644
index 00000000..b557f3f0
--- /dev/null
+++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
@@ -0,0 +1,87 @@
+import os
+import uuid
+from typing import Any, Dict, Union
+
+import json
+import numpy as np
+
+from modelscope.models.nlp import BertForZeroShotClassification
+from modelscope.preprocessors import ZeroShotClassificationPreprocessor
+from modelscope.utils.constant import Tasks
+from ...models import Model
+from ..base import Input, Pipeline
+from ..builder import PIPELINES
+from scipy.special import softmax
+
+__all__ = ['ZeroShotClassificationPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.zero_shot_classification,
+    module_name=r'bert-zero-shot-classification')
+class ZeroShotClassificationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[BertForZeroShotClassification, str],
+                 preprocessor: ZeroShotClassificationPreprocessor = None,
+                 **kwargs):
+        """use `model` and `preprocessor` to create a nlp zero shot classification pipeline for prediction
+
+        Args:
+            model (BertForZeroShotClassification | str): a model instance, or a str
+                that is handed to `Model.from_pretrained`
+            preprocessor (ZeroShotClassificationPreprocessor): a preprocessor instance,
+                created from `model.model_dir` when it is None
+            candidate_labels (keyword, required): the labels to score the input against
+            hypothesis_template (keyword, str): format string applied to every label, default "{}"
+            multi_label (keyword, bool): score every label on its own, default False
+        """
+        assert isinstance(model, str) or isinstance(model, BertForZeroShotClassification), \
+            'model must be a single str or BertForZeroShotClassification'
+        sc_model = model if isinstance(
+            model,
+            BertForZeroShotClassification) else Model.from_pretrained(model)
+
+        # NOTE(review): class indices are hard-coded; presumably they match the
+        # label order of the default checkpoint's NLI head - confirm.
+        self.entailment_id = 0
+        self.contradiction_id = 2
+        self.candidate_labels = kwargs.pop("candidate_labels")
+        self.hypothesis_template = kwargs.pop('hypothesis_template', "{}")
+        self.multi_label = kwargs.pop('multi_label', False)
+
+        if preprocessor is None:
+            preprocessor = ZeroShotClassificationPreprocessor(
+                sc_model.model_dir,
+                candidate_labels=self.candidate_labels,
+                hypothesis_template=self.hypothesis_template
+            )
+        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Any]): the model output; its 'logits' entry holds one row per candidate label
+
+        Returns:
+            Dict[str, Any]: 'labels' and 'scores', both sorted by descending score
+        """
+
+        logits = inputs['logits']
+
+        if self.multi_label or len(self.candidate_labels) == 1:
+            # softmax over (contradiction, entailment) of each label separately
+            logits = logits[..., [self.contradiction_id, self.entailment_id]]
+            scores = softmax(logits, axis=-1)[..., 1]
+        else:
+            # softmax of the entailment logits across all labels
+            logits = logits[..., self.entailment_id]
+            scores = softmax(logits, axis=-1)
+
+        reversed_index = list(reversed(scores.argsort()))
+        result = {
+            "labels": [self.candidate_labels[i] for i in reversed_index],
+            "scores": [scores[i].item() for i in reversed_index],
+        }
+        return result
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 518ea977..b9a6901d 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -5,4 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import *  # noqa F403
-from .nlp import TextGenerationPreprocessor
+from .nlp import TextGenerationPreprocessor, ZeroShotClassificationPreprocessor
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 0de41bfc..4ee3ee6a 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -147,3 +147,51 @@ class TextGenerationPreprocessor(Preprocessor):
         rst['token_type_ids'].append(feature['token_type_ids'])
 
         return {k: torch.tensor(v) for k, v in rst.items()}
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=r'bert-zero-shot-classification')
+class ZeroShotClassificationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+            sequence_length (keyword, int): max_length given to the tokenizer, default 512
+            candidate_labels (keyword, required): the labels paired with every input sentence
+            hypothesis_template (keyword, str): format string applied to every label, default "{}"
+        """
+
+        super().__init__(*args, **kwargs)
+
+        from sofa import SbertTokenizer
+        self.model_dir: str = model_dir
+        self.sequence_length = kwargs.pop('sequence_length', 512)
+        self.candidate_labels = kwargs.pop("candidate_labels")
+        self.hypothesis_template = kwargs.pop('hypothesis_template', "{}")
+        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+
+    @type_assert(object, str)
+    def __call__(self, data: str) -> Dict[str, Any]:
+        """process the raw input data
+
+        Args:
+            data (str): a sentence
+                Example:
+                    'you are so handsome.'
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        pairs = [[data, self.hypothesis_template.format(label)] for label in self.candidate_labels]
+
+        # Only the sentence may be truncated; the label hypothesis must stay intact.
+        features = self.tokenizer(
+            pairs,
+            padding=True,
+            truncation='only_first',
+            max_length=self.sequence_length,
+            return_tensors='pt'
+        )
+        return features
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index fa30dd2a..f1eb1fbd 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -30,6 +30,7 @@ class Tasks(object):
     image_matting = 'image-matting'
 
     # nlp tasks
+    zero_shot_classification = 'zero-shot-classification'
     sentiment_analysis = 'sentiment-analysis'
     text_classification = 'text-classification'
     relation_extraction = 'relation-extraction'
diff --git a/tests/pipelines/test_zero_shot_classification.py b/tests/pipelines/test_zero_shot_classification.py
new file mode 100644
index 00000000..f55324f0
--- /dev/null
+++ b/tests/pipelines/test_zero_shot_classification.py
@@ -0,0 +1,68 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from maas_hub.snapshot_download import snapshot_download
+
+from modelscope.models import Model
+from modelscope.models.nlp import BertForZeroShotClassification
+from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline
+from modelscope.preprocessors import ZeroShotClassificationPreprocessor
+from modelscope.utils.constant import Tasks
+
+
+class ZeroShotClassificationTest(unittest.TestCase):
+    model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
+    sentence = '全新突破 解放军运20版空中加油机曝光'
+    candidate_labels = ["文化", "体育", "娱乐", "财经", "家居", "汽车", "教育", "科技", "军事"]
+
+    def test_run_from_local(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = ZeroShotClassificationPreprocessor(cache_path, candidate_labels=self.candidate_labels)
+        model = BertForZeroShotClassification(
+            cache_path, tokenizer=tokenizer)
+        pipeline1 = ZeroShotClassificationPipeline(
+            model,
+            preprocessor=tokenizer,
+            candidate_labels=self.candidate_labels,
+        )
+        pipeline2 = pipeline(
+            Tasks.zero_shot_classification,
+            model=model,
+            preprocessor=tokenizer,
+            candidate_labels=self.candidate_labels
+        )
+
+        print(f'sentence: {self.sentence}\n'
+              f'pipeline1:{pipeline1(input=self.sentence)}')
+        print()
+        print(f'sentence: {self.sentence}\n'
+              f'pipeline2: {pipeline2(input=self.sentence)}')
+
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = ZeroShotClassificationPreprocessor(model.model_dir, candidate_labels=self.candidate_labels)
+        pipeline_ins = pipeline(
+            task=Tasks.zero_shot_classification,
+            model=model,
+            preprocessor=tokenizer,
+            candidate_labels=self.candidate_labels
+        )
+        print(pipeline_ins(input=self.sentence))
+
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.zero_shot_classification,
+            model=self.model_id,
+            candidate_labels=self.candidate_labels
+        )
+        print(pipeline_ins(input=self.sentence))
+
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(
+            task=Tasks.zero_shot_classification,
+            candidate_labels=self.candidate_labels)
+        print(pipeline_ins(input=self.sentence))
+
+
+if __name__ == '__main__':
+    unittest.main()