Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9180863master
| @@ -52,6 +52,7 @@ class Pipelines(object): | |||
| text_generation = 'text-generation' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| fill_mask = 'fill-mask' | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| # audio tasks | |||
| sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' | |||
| @@ -95,6 +96,7 @@ class Preprocessors(object): | |||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | |||
| palm_text_gen_tokenizer = 'palm-text-gen-tokenizer' | |||
| sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | |||
| zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer' | |||
| # audio preprocessor | |||
| linear_aec_fbank = 'linear-aec-fbank' | |||
| @@ -7,4 +7,5 @@ from .audio.tts.vocoder import Hifigan16k | |||
| from .base import Model | |||
| from .builder import MODELS, build_model | |||
| from .multi_modal import OfaForImageCaptioning | |||
| from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity | |||
| from .nlp import (BertForSequenceClassification, SbertForSentenceSimilarity, | |||
| SbertForZeroShotClassification) | |||
| @@ -3,3 +3,4 @@ from .masked_language_model import * # noqa F403 | |||
| from .palm_for_text_generation import * # noqa F403 | |||
| from .sbert_for_sentence_similarity import * # noqa F403 | |||
| from .sbert_for_token_classification import * # noqa F403 | |||
| from .sbert_for_zero_shot_classification import * # noqa F403 | |||
| @@ -0,0 +1,50 @@ | |||
| from typing import Any, Dict | |||
| import numpy as np | |||
| from modelscope.utils.constant import Tasks | |||
| from ...metainfo import Models | |||
| from ..base import Model | |||
| from ..builder import MODELS | |||
| __all__ = ['SbertForZeroShotClassification'] | |||
| @MODELS.register_module( | |||
| Tasks.zero_shot_classification, module_name=Models.structbert) | |||
| class SbertForZeroShotClassification(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the zero shot classification model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from sofa import SbertForSequenceClassification | |||
| self.model = SbertForSequenceClassification.from_pretrained(model_dir) | |||
| def train(self): | |||
| return self.model.train() | |||
| def eval(self): | |||
| return self.model.eval() | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Any]): the preprocessed data | |||
| Returns: | |||
| Dict[str, np.ndarray]: results | |||
| Example: | |||
| { | |||
| 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
| } | |||
| """ | |||
| outputs = self.model(**input) | |||
| logits = outputs['logits'].numpy() | |||
| res = {'logits': logits} | |||
| return res | |||
| @@ -74,33 +74,57 @@ class Pipeline(ABC): | |||
| self.preprocessor = preprocessor | |||
| def __call__(self, input: Union[Input, List[Input]], *args, | |||
| **post_kwargs) -> Union[Dict[str, Any], Generator]: | |||
| **kwargs) -> Union[Dict[str, Any], Generator]: | |||
| # model provider should leave it as it is | |||
| # modelscope library developer will handle this function | |||
| # simple showcase, need to support iterator type for both tensorflow and pytorch | |||
| # input_dict = self._handle_input(input) | |||
| # sanitize the parameters | |||
| preprocess_params, forward_params, postprocess_params = self._sanitize_parameters( | |||
| **kwargs) | |||
| kwargs['preprocess_params'] = preprocess_params | |||
| kwargs['forward_params'] = forward_params | |||
| kwargs['postprocess_params'] = postprocess_params | |||
| if isinstance(input, list): | |||
| output = [] | |||
| for ele in input: | |||
| output.append(self._process_single(ele, *args, **post_kwargs)) | |||
| output.append(self._process_single(ele, *args, **kwargs)) | |||
| elif isinstance(input, MsDataset): | |||
| return self._process_iterator(input, *args, **post_kwargs) | |||
| return self._process_iterator(input, *args, **kwargs) | |||
| else: | |||
| output = self._process_single(input, *args, **post_kwargs) | |||
| output = self._process_single(input, *args, **kwargs) | |||
| return output | |||
| def _process_iterator(self, input: Input, *args, **post_kwargs): | |||
| def _sanitize_parameters(self, **pipeline_parameters): | |||
| """ | |||
| this method should sanitize the keyword args to preprocessor params, | |||
| forward params and postprocess params on '__call__' or '_process_single' method | |||
| considered to be a normal classmethod with default implementation / output | |||
| Default Returns: | |||
| Dict[str, str]: preprocess_params = {} | |||
| Dict[str, str]: forward_params = {} | |||
| Dict[str, str]: postprocess_params = pipeline_parameters | |||
| """ | |||
| return {}, {}, pipeline_parameters | |||
| def _process_iterator(self, input: Input, *args, **kwargs): | |||
| for ele in input: | |||
| yield self._process_single(ele, *args, **post_kwargs) | |||
| yield self._process_single(ele, *args, **kwargs) | |||
| def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | |||
| preprocess_params = kwargs.get('preprocess_params') | |||
| forward_params = kwargs.get('forward_params') | |||
| postprocess_params = kwargs.get('postprocess_params') | |||
| def _process_single(self, input: Input, *args, | |||
| **post_kwargs) -> Dict[str, Any]: | |||
| out = self.preprocess(input) | |||
| out = self.forward(out) | |||
| out = self.postprocess(out, **post_kwargs) | |||
| out = self.preprocess(input, **preprocess_params) | |||
| out = self.forward(out, **forward_params) | |||
| out = self.postprocess(out, **postprocess_params) | |||
| self._check_output(out) | |||
| return out | |||
| @@ -120,20 +144,21 @@ class Pipeline(ABC): | |||
| raise ValueError(f'expected output keys are {output_keys}, ' | |||
| f'those {missing_keys} are missing') | |||
| def preprocess(self, inputs: Input) -> Dict[str, Any]: | |||
| def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: | |||
| """ Provide default implementation based on preprocess_cfg and user can reimplement it | |||
| """ | |||
| assert self.preprocessor is not None, 'preprocess method should be implemented' | |||
| assert not isinstance(self.preprocessor, List),\ | |||
| 'default implementation does not support using multiple preprocessors.' | |||
| return self.preprocessor(inputs) | |||
| return self.preprocessor(inputs, **preprocess_params) | |||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| """ Provide default implementation using self.model and user can reimplement it | |||
| """ | |||
| assert self.model is not None, 'forward method should be implemented' | |||
| assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.' | |||
| return self.model(inputs) | |||
| return self.model(inputs, **forward_params) | |||
| @abstractmethod | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| @@ -27,6 +27,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/bert-base-sst2'), | |||
| Tasks.text_generation: (Pipelines.text_generation, | |||
| 'damo/nlp_palm2.0_text-generation_chinese-base'), | |||
| Tasks.zero_shot_classification: | |||
| (Pipelines.zero_shot_classification, | |||
| 'damo/nlp_structbert_zero-shot-classification_chinese-base'), | |||
| Tasks.image_captioning: (Pipelines.image_caption, | |||
| 'damo/ofa_image-caption_coco_large_en'), | |||
| Tasks.image_generation: | |||
| @@ -3,3 +3,4 @@ from .sentence_similarity_pipeline import * # noqa F403 | |||
| from .sequence_classification_pipeline import * # noqa F403 | |||
| from .text_generation_pipeline import * # noqa F403 | |||
| from .word_segmentation_pipeline import * # noqa F403 | |||
| from .zero_shot_classification_pipeline import * # noqa F403 | |||
| @@ -0,0 +1,97 @@ | |||
| import os | |||
| import uuid | |||
| from typing import Any, Dict, Union | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from scipy.special import softmax | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.nlp import SbertForZeroShotClassification | |||
| from ...preprocessors import ZeroShotClassificationPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Input, Pipeline | |||
| from ..builder import PIPELINES | |||
| __all__ = ['ZeroShotClassificationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.zero_shot_classification, | |||
| module_name=Pipelines.zero_shot_classification) | |||
| class ZeroShotClassificationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[SbertForZeroShotClassification, str], | |||
| preprocessor: ZeroShotClassificationPreprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SbertForSentimentClassification): a model instance | |||
| preprocessor (SentimentClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \ | |||
| 'model must be a single str or SbertForZeroShotClassification' | |||
| model = model if isinstance( | |||
| model, | |||
| SbertForZeroShotClassification) else Model.from_pretrained(model) | |||
| self.entailment_id = 0 | |||
| self.contradiction_id = 2 | |||
| if preprocessor is None: | |||
| preprocessor = ZeroShotClassificationPreprocessor(model.model_dir) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| def _sanitize_parameters(self, **kwargs): | |||
| preprocess_params = {} | |||
| postprocess_params = {} | |||
| if 'candidate_labels' in kwargs: | |||
| candidate_labels = kwargs.pop('candidate_labels') | |||
| preprocess_params['candidate_labels'] = candidate_labels | |||
| postprocess_params['candidate_labels'] = candidate_labels | |||
| else: | |||
| raise ValueError('You must include at least one label.') | |||
| preprocess_params['hypothesis_template'] = kwargs.pop( | |||
| 'hypothesis_template', '{}') | |||
| postprocess_params['multi_label'] = kwargs.pop('multi_label', False) | |||
| return preprocess_params, {}, postprocess_params | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, | |||
| inputs: Dict[str, Any], | |||
| candidate_labels, | |||
| multi_label=False) -> Dict[str, Any]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, Any]: the prediction results | |||
| """ | |||
| logits = inputs['logits'] | |||
| if multi_label or len(candidate_labels) == 1: | |||
| logits = logits[..., [self.contradiction_id, self.entailment_id]] | |||
| scores = softmax(logits, axis=-1)[..., 1] | |||
| else: | |||
| logits = logits[..., self.entailment_id] | |||
| scores = softmax(logits, axis=-1) | |||
| reversed_index = list(reversed(scores.argsort())) | |||
| result = { | |||
| 'labels': [candidate_labels[i] for i in reversed_index], | |||
| 'scores': [scores[i].item() for i in reversed_index], | |||
| } | |||
| return result | |||
| @@ -101,6 +101,13 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.sentence_similarity: ['scores', 'labels'], | |||
| # zero-shot classification result for single sample | |||
| # { | |||
| # "labels": ["happy", "sad", "calm", "angry"], | |||
| # "scores": [0.9, 0.1, 0.05, 0.05] | |||
| # } | |||
| Tasks.zero_shot_classification: ['scores', 'labels'], | |||
| # ============ audio tasks =================== | |||
| # audio processed for single file in PCM format | |||
| @@ -14,7 +14,7 @@ from .builder import PREPROCESSORS | |||
| __all__ = [ | |||
| 'Tokenize', 'SequenceClassificationPreprocessor', | |||
| 'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor', | |||
| 'FillMaskPreprocessor' | |||
| 'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor' | |||
| ] | |||
| @@ -286,3 +286,47 @@ class TokenClassifcationPreprocessor(Preprocessor): | |||
| 'attention_mask': attention_mask, | |||
| 'token_type_ids': token_type_ids | |||
| } | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer) | |||
| class ZeroShotClassificationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via the vocab.txt from the `model_dir` path | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| from sofa import SbertTokenizer | |||
| self.model_dir: str = model_dir | |||
| self.sequence_length = kwargs.pop('sequence_length', 512) | |||
| self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir) | |||
| @type_assert(object, str) | |||
| def __call__(self, data: str, hypothesis_template: str, | |||
| candidate_labels: list) -> Dict[str, Any]: | |||
| """process the raw input data | |||
| Args: | |||
| data (str): a sentence | |||
| Example: | |||
| 'you are so handsome.' | |||
| Returns: | |||
| Dict[str, Any]: the preprocessed data | |||
| """ | |||
| pairs = [[data, hypothesis_template.format(label)] | |||
| for label in candidate_labels] | |||
| features = self.tokenizer( | |||
| pairs, | |||
| padding=True, | |||
| truncation=True, | |||
| max_length=self.sequence_length, | |||
| return_tensors='pt', | |||
| truncation_strategy='only_first') | |||
| return features | |||
| @@ -48,6 +48,7 @@ class Tasks(object): | |||
| fill_mask = 'fill-mask' | |||
| summarization = 'summarization' | |||
| question_answering = 'question-answering' | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| # audio tasks | |||
| auto_speech_recognition = 'auto-speech-recognition' | |||
| @@ -0,0 +1,64 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import SbertForZeroShotClassification | |||
| from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline | |||
| from modelscope.preprocessors import ZeroShotClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class ZeroShotClassificationTest(unittest.TestCase): | |||
| model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base' | |||
| sentence = '全新突破 解放军运20版空中加油机曝光' | |||
| labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'] | |||
| template = '这篇文章的标题是{}' | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_direct_file_download(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| tokenizer = ZeroShotClassificationPreprocessor(cache_path) | |||
| model = SbertForZeroShotClassification(cache_path, tokenizer=tokenizer) | |||
| pipeline1 = ZeroShotClassificationPipeline( | |||
| model, preprocessor=tokenizer) | |||
| pipeline2 = pipeline( | |||
| Tasks.zero_shot_classification, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print( | |||
| f'sentence: {self.sentence}\n' | |||
| f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}' | |||
| ) | |||
| print() | |||
| print( | |||
| f'sentence: {self.sentence}\n' | |||
| f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}' | |||
| ) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| tokenizer = ZeroShotClassificationPreprocessor(model.model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.zero_shot_classification, | |||
| model=model, | |||
| preprocessor=tokenizer) | |||
| print(pipeline_ins(input=self.sentence, candidate_labels=self.labels)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.zero_shot_classification, model=self.model_id) | |||
| print(pipeline_ins(input=self.sentence, candidate_labels=self.labels)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_ins = pipeline(task=Tasks.zero_shot_classification) | |||
| print(pipeline_ins(input=self.sentence, candidate_labels=self.labels)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||