From 576b7cffb11532c3431fbfc2998ae833408c327b Mon Sep 17 00:00:00 2001
From: "zhangzhicheng.zzc"
Date: Wed, 29 Jun 2022 09:12:59 +0800
Subject: [PATCH] [to #42322933] add pipeline params for preprocess and forward & zero-shot classification

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9180863
---
 modelscope/metainfo.py                        |  2 +
 modelscope/models/__init__.py                 |  3 +-
 modelscope/models/nlp/__init__.py             |  1 +
 .../nlp/sbert_for_zero_shot_classification.py | 50 ++++++++++
 modelscope/pipelines/base.py                  | 55 ++++++++---
 modelscope/pipelines/builder.py               |  3 +
 modelscope/pipelines/nlp/__init__.py          |  1 +
 .../nlp/zero_shot_classification_pipeline.py  | 97 +++++++++++++++++++
 modelscope/pipelines/outputs.py               |  7 ++
 modelscope/preprocessors/nlp.py               | 49 ++++++++-
 modelscope/utils/constant.py                  |  1 +
 .../test_zero_shot_classification.py          | 64 ++++++++++++
 12 files changed, 316 insertions(+), 17 deletions(-)
 create mode 100644 modelscope/models/nlp/sbert_for_zero_shot_classification.py
 create mode 100644 modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
 create mode 100644 tests/pipelines/test_zero_shot_classification.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index eda590ac..1d2ee4d2 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -52,6 +52,7 @@ class Pipelines(object):
     text_generation = 'text-generation'
     sentiment_analysis = 'sentiment-analysis'
     fill_mask = 'fill-mask'
+    zero_shot_classification = 'zero-shot-classification'

     # audio tasks
     sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
@@ -95,6 +96,7 @@ class Preprocessors(object):
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
     sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
+    zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'

     # audio preprocessor
     linear_aec_fbank = 'linear-aec-fbank'
diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py
index 816c44e2..f1074f68 100644
--- a/modelscope/models/__init__.py
+++ b/modelscope/models/__init__.py
@@ -7,4 +7,5 @@ from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
 from .multi_modal import OfaForImageCaptioning
-from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
+from .nlp import (BertForSequenceClassification, SbertForSentenceSimilarity,
+                  SbertForZeroShotClassification)
diff --git a/modelscope/models/nlp/__init__.py b/modelscope/models/nlp/__init__.py
index 6be4493b..f904efdf 100644
--- a/modelscope/models/nlp/__init__.py
+++ b/modelscope/models/nlp/__init__.py
@@ -3,3 +3,4 @@ from .masked_language_model import *  # noqa F403
 from .palm_for_text_generation import *  # noqa F403
 from .sbert_for_sentence_similarity import *  # noqa F403
 from .sbert_for_token_classification import *  # noqa F403
+from .sbert_for_zero_shot_classification import *  # noqa F403
diff --git a/modelscope/models/nlp/sbert_for_zero_shot_classification.py b/modelscope/models/nlp/sbert_for_zero_shot_classification.py
new file mode 100644
index 00000000..837bb41e
--- /dev/null
+++ b/modelscope/models/nlp/sbert_for_zero_shot_classification.py
@@ -0,0 +1,50 @@
+from typing import Any, Dict
+
+import numpy as np
+
+from modelscope.utils.constant import Tasks
+from ...metainfo import Models
+from ..base import Model
+from ..builder import MODELS
+
+__all__ = ['SbertForZeroShotClassification']
+
+
+@MODELS.register_module(
+    Tasks.zero_shot_classification, module_name=Models.structbert)
+class SbertForZeroShotClassification(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the zero-shot classification model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+
+        super().__init__(model_dir, *args, **kwargs)
+        from sofa import SbertForSequenceClassification
+        self.model = SbertForSequenceClassification.from_pretrained(model_dir)
+
+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
+        """Run the model on the preprocessed input and return the raw logits.
+
+        Args:
+            input (Dict[str, Any]): the preprocessed data
+
+        Returns:
+            Dict[str, np.ndarray]: results
+        Example:
+            {
+                'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)
+            }
+        """
+        outputs = self.model(**input)
+        logits = outputs['logits'].numpy()
+        res = {'logits': logits}
+        return res
diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index 2f5d5dcc..4052d35a 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -74,33 +74,57 @@ class Pipeline(ABC):
         self.preprocessor = preprocessor

     def __call__(self, input: Union[Input, List[Input]], *args,
-                 **post_kwargs) -> Union[Dict[str, Any], Generator]:
+                 **kwargs) -> Union[Dict[str, Any], Generator]:
         # model provider should leave it as it is
         # modelscope library developer will handle this function
         # simple showcase, need to support iterator type for both tensorflow and pytorch
         # input_dict = self._handle_input(input)
+
+        # split the keyword args into preprocess/forward/postprocess params
+        preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
+            **kwargs)
+        kwargs['preprocess_params'] = preprocess_params
+        kwargs['forward_params'] = forward_params
+        kwargs['postprocess_params'] = postprocess_params
+
         if isinstance(input, list):
             output = []
             for ele in input:
-                output.append(self._process_single(ele, *args, **post_kwargs))
+                output.append(self._process_single(ele, *args, **kwargs))
         elif isinstance(input, MsDataset):
-            return self._process_iterator(input, *args, **post_kwargs)
+            return self._process_iterator(input, *args, **kwargs)
         else:
-            output = self._process_single(input, *args, **post_kwargs)
+            output = self._process_single(input, *args, **kwargs)
         return output

-    def _process_iterator(self, input: Input, *args, **post_kwargs):
+    def _sanitize_parameters(self, **pipeline_parameters):
+        """Split the keyword args passed to `__call__` (or
+        `_process_single`) into preprocess params, forward params
+        and postprocess params. Subclasses may override this; the
+        default implementation routes every keyword arg to `postprocess`.
+
+        Returns:
+            Dict[str, Any]: preprocess_params = {}
+            Dict[str, Any]: forward_params = {}
+            Dict[str, Any]: postprocess_params = pipeline_parameters
+        """
+        return {}, {}, pipeline_parameters
+
+    def _process_iterator(self, input: Input, *args, **kwargs):
         for ele in input:
-            yield self._process_single(ele, *args, **post_kwargs)
+            yield self._process_single(ele, *args, **kwargs)
+
+    def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
+        preprocess_params = kwargs.get('preprocess_params', {})
+        forward_params = kwargs.get('forward_params', {})
+        postprocess_params = kwargs.get('postprocess_params', {})

-    def _process_single(self, input: Input, *args,
-                        **post_kwargs) -> Dict[str, Any]:
-        out = self.preprocess(input)
-        out = self.forward(out)
-        out = self.postprocess(out, **post_kwargs)
+        out = self.preprocess(input, **preprocess_params)
+        out = self.forward(out, **forward_params)
+        out = self.postprocess(out, **postprocess_params)
         self._check_output(out)
         return out

@@ -120,20 +144,21 @@ class Pipeline(ABC):
             raise ValueError(f'expected output keys are {output_keys}, '
                              f'those {missing_keys} are missing')

-    def preprocess(self, inputs: Input) -> Dict[str, Any]:
+    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
         """ Provide default implementation based on preprocess_cfg and user can reimplement it
         """
         assert self.preprocessor is not None, 'preprocess method should be implemented'
         assert not isinstance(self.preprocessor, List),\
             'default implementation does not support using multiple preprocessors.'
-        return self.preprocessor(inputs)
+        return self.preprocessor(inputs, **preprocess_params)

-    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
         """ Provide default implementation using self.model and user can reimplement it
         """
         assert self.model is not None, 'forward method should be implemented'
         assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
-        return self.model(inputs)
+        return self.model(inputs, **forward_params)

     @abstractmethod
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 41cd73da..847955d4 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -27,6 +27,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                              'damo/bert-base-sst2'),
     Tasks.text_generation: (Pipelines.text_generation,
                             'damo/nlp_palm2.0_text-generation_chinese-base'),
+    Tasks.zero_shot_classification:
+    (Pipelines.zero_shot_classification,
+     'damo/nlp_structbert_zero-shot-classification_chinese-base'),
     Tasks.image_captioning: (Pipelines.image_caption,
                              'damo/ofa_image-caption_coco_large_en'),
     Tasks.image_generation:
diff --git a/modelscope/pipelines/nlp/__init__.py b/modelscope/pipelines/nlp/__init__.py
index c50875fd..5ef12e22 100644
--- a/modelscope/pipelines/nlp/__init__.py
+++ b/modelscope/pipelines/nlp/__init__.py
@@ -3,3 +3,4 @@ from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
 from .word_segmentation_pipeline import *  # noqa F403
+from .zero_shot_classification_pipeline import *  # noqa F403
diff --git a/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
new file mode 100644
index 00000000..2ed4dac3
--- /dev/null
+++ b/modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
@@ -0,0 +1,97 @@
+from typing import Any, Dict, Union
+
+import torch
+from scipy.special import softmax
+
+from ...metainfo import Pipelines
+from ...models import Model
+from ...models.nlp import SbertForZeroShotClassification
+from ...preprocessors import ZeroShotClassificationPreprocessor
+from ...utils.constant import Tasks
+from ..base import Pipeline
+from ..builder import PIPELINES
+
+__all__ = ['ZeroShotClassificationPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.zero_shot_classification,
+    module_name=Pipelines.zero_shot_classification)
+class ZeroShotClassificationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[SbertForZeroShotClassification, str],
+                 preprocessor: ZeroShotClassificationPreprocessor = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create a zero-shot classification pipeline for prediction.
+
+        Args:
+            model (SbertForZeroShotClassification): a model instance
+            preprocessor (ZeroShotClassificationPreprocessor): a preprocessor instance
+        """
+        assert isinstance(model, (str, SbertForZeroShotClassification)), \
+            'model must be a single str or SbertForZeroShotClassification'
+        model = model if isinstance(
+            model,
+            SbertForZeroShotClassification) else Model.from_pretrained(model)
+
+        self.entailment_id = 0
+        self.contradiction_id = 2
+
+        if preprocessor is None:
+            preprocessor = ZeroShotClassificationPreprocessor(model.model_dir)
+        model.eval()
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_params = {}
+        postprocess_params = {}
+
+        if 'candidate_labels' in kwargs:
+            candidate_labels = kwargs.pop('candidate_labels')
+            preprocess_params['candidate_labels'] = candidate_labels
+            postprocess_params['candidate_labels'] = candidate_labels
+        else:
+            raise ValueError('You must provide at least one candidate label.')
+        preprocess_params['hypothesis_template'] = kwargs.pop(
+            'hypothesis_template', '{}')
+
+        postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
+        return preprocess_params, {}, postprocess_params
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self,
+                    inputs: Dict[str, Any],
+                    candidate_labels,
+                    multi_label=False) -> Dict[str, Any]:
+        """Convert the model logits into candidate labels with sorted scores.
+
+        Args:
+            inputs (Dict[str, Any]): the model outputs, containing 'logits'
+            candidate_labels: the candidate labels used at preprocess time
+            multi_label (bool): if True, score each label independently
+
+        Returns:
+            Dict[str, Any]: the prediction results
+        """
+
+        logits = inputs['logits']
+        if multi_label or len(candidate_labels) == 1:
+            # score each label independently: softmax over its own
+            # (contradiction, entailment) pair
+            logits = logits[..., [self.contradiction_id, self.entailment_id]]
+            scores = softmax(logits, axis=-1)[..., 1]
+        else:
+            # single-label: softmax of the entailment logits across labels
+            logits = logits[..., self.entailment_id]
+            scores = softmax(logits, axis=-1)
+
+        reversed_index = list(reversed(scores.argsort()))
+        result = {
+            'labels': [candidate_labels[i] for i in reversed_index],
+            'scores': [scores[i].item() for i in reversed_index],
+        }
+        return result
diff --git a/modelscope/pipelines/outputs.py b/modelscope/pipelines/outputs.py
index 52b7eeae..290e6717 100644
--- a/modelscope/pipelines/outputs.py
+++ b/modelscope/pipelines/outputs.py
@@ -101,6 +101,13 @@ TASK_OUTPUTS = {
     # }
     Tasks.sentence_similarity: ['scores', 'labels'],

+    # zero-shot classification result for single sample
+    # {
+    #     "labels": ["happy", "sad", "calm", "angry"],
+    #     "scores": [0.9, 0.05, 0.03, 0.02]
+    # }
+    Tasks.zero_shot_classification: ['scores', 'labels'],
+
     # ============ audio tasks ===================

     # audio processed for single file in PCM format
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 4ed63f3c..e8e33e74 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -14,7 +14,7 @@ from .builder import PREPROCESSORS
 __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
     'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
-    'FillMaskPreprocessor'
+    'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor'
 ]


@@ -286,3 +286,50 @@ class TokenClassifcationPreprocessor(Preprocessor):
             'attention_mask': attention_mask,
             'token_type_ids': token_type_ids
         }
+
+
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer)
+class ZeroShotClassificationPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data using the vocabulary loaded from the `model_dir` path.
+
+        Args:
+            model_dir (str): model path
+        """
+
+        super().__init__(*args, **kwargs)
+
+        from sofa import SbertTokenizer
+        self.model_dir: str = model_dir
+        self.sequence_length = kwargs.pop('sequence_length', 512)
+        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+
+    @type_assert(object, str)
+    def __call__(self, data: str, hypothesis_template: str,
+                 candidate_labels: list) -> Dict[str, Any]:
+        """Process the raw input data.
+
+        Args:
+            data (str): a sentence to classify
+                Example:
+                    'you are so handsome.'
+            hypothesis_template (str): a template used to turn each
+                candidate label into an NLI hypothesis
+            candidate_labels (list): the candidate labels
+
+        Returns:
+            Dict[str, Any]: the preprocessed data
+        """
+        pairs = [[data, hypothesis_template.format(label)]
+                 for label in candidate_labels]
+
+        # truncate only the premise so the label hypothesis stays intact
+        features = self.tokenizer(
+            pairs,
+            padding=True,
+            truncation='only_first',
+            max_length=self.sequence_length,
+            return_tensors='pt')
+        return features
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 55f015e8..44bd1dff 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -48,6 +48,7 @@ class Tasks(object):
     fill_mask = 'fill-mask'
     summarization = 'summarization'
     question_answering = 'question-answering'
+    zero_shot_classification = 'zero-shot-classification'

     # audio tasks
     auto_speech_recognition = 'auto-speech-recognition'
diff --git a/tests/pipelines/test_zero_shot_classification.py b/tests/pipelines/test_zero_shot_classification.py
new file mode 100644
index 00000000..b76a6a86
--- /dev/null
+++ b/tests/pipelines/test_zero_shot_classification.py
@@ -0,0 +1,64 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.nlp import SbertForZeroShotClassification
+from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline
+from modelscope.preprocessors import ZeroShotClassificationPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class ZeroShotClassificationTest(unittest.TestCase):
+    model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
+    sentence = '全新突破 解放军运20版空中加油机曝光'
+    labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
+    template = '这篇文章的标题是{}'
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_direct_file_download(self):
+        cache_path = snapshot_download(self.model_id)
+        tokenizer = ZeroShotClassificationPreprocessor(cache_path)
+        model = SbertForZeroShotClassification(cache_path, tokenizer=tokenizer)
+        pipeline1 = ZeroShotClassificationPipeline(
+            model, preprocessor=tokenizer)
+        pipeline2 = pipeline(
+            Tasks.zero_shot_classification,
+            model=model,
+            preprocessor=tokenizer)
+
+        print(
+            f'sentence: {self.sentence}\n'
+            f'pipeline1: {pipeline1(input=self.sentence, candidate_labels=self.labels)}'
+        )
+        print()
+        print(
+            f'sentence: {self.sentence}\n'
+            f'pipeline2: {pipeline2(self.sentence, candidate_labels=self.labels, hypothesis_template=self.template)}'
+        )
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        tokenizer = ZeroShotClassificationPreprocessor(model.model_dir)
+        pipeline_ins = pipeline(
+            task=Tasks.zero_shot_classification,
+            model=model,
+            preprocessor=tokenizer)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_ins = pipeline(
+            task=Tasks.zero_shot_classification, model=self.model_id)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.zero_shot_classification)
+        print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))
+
+
+if __name__ == '__main__':
+    unittest.main()
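
For reviewers, a minimal end-to-end sketch of how the new pipeline is meant to be invoked once this patch lands. It mirrors test_run_with_model_name above; the only assumption is that the default model can be fetched from the ModelScope hub.

# Minimal usage sketch (mirrors test_run_with_model_name in this patch).
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

classifier = pipeline(
    task=Tasks.zero_shot_classification,
    model='damo/nlp_structbert_zero-shot-classification_chinese-base')

# candidate_labels is required by _sanitize_parameters and is routed to both
# preprocess and postprocess; hypothesis_template turns each label into an
# NLI hypothesis that is paired with the input sentence.
result = classifier(
    '全新突破 解放军运20版空中加油机曝光',
    candidate_labels=['军事', '体育', '科技'],
    hypothesis_template='这篇文章的标题是{}')
print(result)  # {'labels': [...], 'scores': [...]}, sorted by descending score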
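The _sanitize_parameters hook in modelscope/pipelines/base.py is the generic piece of this change: __call__ passes all keyword args through it, and the three returned dicts are forwarded to preprocess, forward and postprocess respectively. Below is a schematic sketch of that contract from a subclass's point of view; MyPipeline and the max_length knob are made-up illustrations, only the three-dict return shape comes from the patch.

# Schematic sketch of the parameter-routing contract (hypothetical subclass).
from typing import Any, Dict

from modelscope.pipelines.base import Pipeline


class MyPipeline(Pipeline):

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        if 'max_length' in kwargs:  # hypothetical preprocessor knob
            preprocess_params['max_length'] = kwargs.pop('max_length')
        # whatever is left over is handed to postprocess, like the default
        return preprocess_params, {}, kwargs

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        return inputs


# my_pipeline(text, max_length=128, top_k=5) would then call
# preprocess(..., max_length=128), forward(...) and postprocess(..., top_k=5).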
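The two scoring modes in ZeroShotClassificationPipeline.postprocess can also be sanity-checked in isolation. The NLI logits below are invented; the index layout (entailment_id = 0, contradiction_id = 2) follows the patch.

# Standalone check of the two scoring modes in postprocess. One row per
# (sentence, label) pair; columns are the three NLI classes, entailment at
# index 0 and contradiction at index 2 (made-up values).
import numpy as np
from scipy.special import softmax

logits = np.array([[2.1, 0.3, -1.2],   # pair for label 0
                   [-0.5, 0.1, 1.7]])  # pair for label 1
entailment_id, contradiction_id = 0, 2

# multi_label=True: each label is scored independently via a softmax over its
# own (contradiction, entailment) pair; scores need not sum to 1.
multi = softmax(logits[..., [contradiction_id, entailment_id]], axis=-1)[..., 1]

# multi_label=False: labels compete; a single softmax over the entailment
# column across all labels, so the scores sum to 1.
single = softmax(logits[..., entailment_id], axis=-1)

print(multi)   # ~[0.96, 0.10]: label 0 clearly applies, label 1 does not
print(single)  # ~[0.93, 0.07]: a probability distribution over both labels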