Browse Source

[to #42322933] add pipeline params for preprocess and forward & zeroshot classification

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9180863
master
zhangzhicheng.zzc yingda.chen 3 years ago
parent
commit
576b7cffb1
12 changed files with 313 additions and 17 deletions
  1. +2
    -0
      modelscope/metainfo.py
  2. +2
    -1
      modelscope/models/__init__.py
  3. +1
    -0
      modelscope/models/nlp/__init__.py
  4. +50
    -0
      modelscope/models/nlp/sbert_for_zero_shot_classification.py
  5. +40
    -15
      modelscope/pipelines/base.py
  6. +3
    -0
      modelscope/pipelines/builder.py
  7. +1
    -0
      modelscope/pipelines/nlp/__init__.py
  8. +97
    -0
      modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
  9. +7
    -0
      modelscope/pipelines/outputs.py
  10. +45
    -1
      modelscope/preprocessors/nlp.py
  11. +1
    -0
      modelscope/utils/constant.py
  12. +64
    -0
      tests/pipelines/test_zero_shot_classification.py

+ 2
- 0
modelscope/metainfo.py View File

@@ -52,6 +52,7 @@ class Pipelines(object):
text_generation = 'text-generation'
sentiment_analysis = 'sentiment-analysis'
fill_mask = 'fill-mask'
zero_shot_classification = 'zero-shot-classification'

# audio tasks
sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
@@ -95,6 +96,7 @@ class Preprocessors(object):
bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer'
zero_shot_cls_tokenizer = 'zero-shot-cls-tokenizer'

# audio preprocessor
linear_aec_fbank = 'linear-aec-fbank'


+ 2
- 1
modelscope/models/__init__.py View File

@@ -7,4 +7,5 @@ from .audio.tts.vocoder import Hifigan16k
from .base import Model
from .builder import MODELS, build_model
from .multi_modal import OfaForImageCaptioning
from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity
from .nlp import (BertForSequenceClassification, SbertForSentenceSimilarity,
SbertForZeroShotClassification)

+ 1
- 0
modelscope/models/nlp/__init__.py View File

@@ -3,3 +3,4 @@ from .masked_language_model import * # noqa F403
from .palm_for_text_generation import * # noqa F403
from .sbert_for_sentence_similarity import * # noqa F403
from .sbert_for_token_classification import * # noqa F403
from .sbert_for_zero_shot_classification import * # noqa F403

+ 50
- 0
modelscope/models/nlp/sbert_for_zero_shot_classification.py View File

@@ -0,0 +1,50 @@
from typing import Any, Dict

import numpy as np

from modelscope.utils.constant import Tasks
from ...metainfo import Models
from ..base import Model
from ..builder import MODELS

__all__ = ['SbertForZeroShotClassification']


@MODELS.register_module(
Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForZeroShotClassification(Model):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the zero shot classification model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""

super().__init__(model_dir, *args, **kwargs)
from sofa import SbertForSequenceClassification
self.model = SbertForSequenceClassification.from_pretrained(model_dir)

def train(self):
return self.model.train()

def eval(self):
return self.model.eval()

def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
Example:
{
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
}
"""
outputs = self.model(**input)
logits = outputs['logits'].numpy()
res = {'logits': logits}
return res

+ 40
- 15
modelscope/pipelines/base.py View File

@@ -74,33 +74,57 @@ class Pipeline(ABC):
self.preprocessor = preprocessor

def __call__(self, input: Union[Input, List[Input]], *args,
**post_kwargs) -> Union[Dict[str, Any], Generator]:
**kwargs) -> Union[Dict[str, Any], Generator]:
# model provider should leave it as it is
# modelscope library developer will handle this function

# simple showcase, need to support iterator type for both tensorflow and pytorch
# input_dict = self._handle_input(input)

# sanitize the parameters
preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
**kwargs)
kwargs['preprocess_params'] = preprocess_params
kwargs['forward_params'] = forward_params
kwargs['postprocess_params'] = postprocess_params

if isinstance(input, list):
output = []
for ele in input:
output.append(self._process_single(ele, *args, **post_kwargs))
output.append(self._process_single(ele, *args, **kwargs))

elif isinstance(input, MsDataset):
return self._process_iterator(input, *args, **post_kwargs)
return self._process_iterator(input, *args, **kwargs)

else:
output = self._process_single(input, *args, **post_kwargs)
output = self._process_single(input, *args, **kwargs)
return output

def _process_iterator(self, input: Input, *args, **post_kwargs):
def _sanitize_parameters(self, **pipeline_parameters):
"""
this method should sanitize the keyword args to preprocessor params,
forward params and postprocess params on '__call__' or '_process_single' method
considered to be a normal classmethod with default implementation / output

Default Returns:
Dict[str, str]: preprocess_params = {}
Dict[str, str]: forward_params = {}
Dict[str, str]: postprocess_params = pipeline_parameters
"""
return {}, {}, pipeline_parameters

def _process_iterator(self, input: Input, *args, **kwargs):
for ele in input:
yield self._process_single(ele, *args, **post_kwargs)
yield self._process_single(ele, *args, **kwargs)

def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
preprocess_params = kwargs.get('preprocess_params')
forward_params = kwargs.get('forward_params')
postprocess_params = kwargs.get('postprocess_params')

def _process_single(self, input: Input, *args,
**post_kwargs) -> Dict[str, Any]:
out = self.preprocess(input)
out = self.forward(out)
out = self.postprocess(out, **post_kwargs)
out = self.preprocess(input, **preprocess_params)
out = self.forward(out, **forward_params)
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
return out

@@ -120,20 +144,21 @@ class Pipeline(ABC):
raise ValueError(f'expected output keys are {output_keys}, '
f'those {missing_keys} are missing')

def preprocess(self, inputs: Input) -> Dict[str, Any]:
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
""" Provide default implementation based on preprocess_cfg and user can reimplement it
"""
assert self.preprocessor is not None, 'preprocess method should be implemented'
assert not isinstance(self.preprocessor, List),\
'default implementation does not support using multiple preprocessors.'
return self.preprocessor(inputs)
return self.preprocessor(inputs, **preprocess_params)

def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
""" Provide default implementation using self.model and user can reimplement it
"""
assert self.model is not None, 'forward method should be implemented'
assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
return self.model(inputs)
return self.model(inputs, **forward_params)

@abstractmethod
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:


+ 3
- 0
modelscope/pipelines/builder.py View File

@@ -27,6 +27,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/bert-base-sst2'),
Tasks.text_generation: (Pipelines.text_generation,
'damo/nlp_palm2.0_text-generation_chinese-base'),
Tasks.zero_shot_classification:
(Pipelines.zero_shot_classification,
'damo/nlp_structbert_zero-shot-classification_chinese-base'),
Tasks.image_captioning: (Pipelines.image_caption,
'damo/ofa_image-caption_coco_large_en'),
Tasks.image_generation:


+ 1
- 0
modelscope/pipelines/nlp/__init__.py View File

@@ -3,3 +3,4 @@ from .sentence_similarity_pipeline import * # noqa F403
from .sequence_classification_pipeline import * # noqa F403
from .text_generation_pipeline import * # noqa F403
from .word_segmentation_pipeline import * # noqa F403
from .zero_shot_classification_pipeline import * # noqa F403

+ 97
- 0
modelscope/pipelines/nlp/zero_shot_classification_pipeline.py View File

@@ -0,0 +1,97 @@
import os
import uuid
from typing import Any, Dict, Union

import json
import numpy as np
import torch
from scipy.special import softmax

from ...metainfo import Pipelines
from ...models import Model
from ...models.nlp import SbertForZeroShotClassification
from ...preprocessors import ZeroShotClassificationPreprocessor
from ...utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['ZeroShotClassificationPipeline']


@PIPELINES.register_module(
Tasks.zero_shot_classification,
module_name=Pipelines.zero_shot_classification)
class ZeroShotClassificationPipeline(Pipeline):

def __init__(self,
model: Union[SbertForZeroShotClassification, str],
preprocessor: ZeroShotClassificationPreprocessor = None,
**kwargs):
"""use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

Args:
model (SbertForSentimentClassification): a model instance
preprocessor (SentimentClassificationPreprocessor): a preprocessor instance
"""
assert isinstance(model, str) or isinstance(model, SbertForZeroShotClassification), \
'model must be a single str or SbertForZeroShotClassification'
model = model if isinstance(
model,
SbertForZeroShotClassification) else Model.from_pretrained(model)

self.entailment_id = 0
self.contradiction_id = 2

if preprocessor is None:
preprocessor = ZeroShotClassificationPreprocessor(model.model_dir)
model.eval()
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
postprocess_params = {}

if 'candidate_labels' in kwargs:
candidate_labels = kwargs.pop('candidate_labels')
preprocess_params['candidate_labels'] = candidate_labels
postprocess_params['candidate_labels'] = candidate_labels
else:
raise ValueError('You must include at least one label.')
preprocess_params['hypothesis_template'] = kwargs.pop(
'hypothesis_template', '{}')

postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
return preprocess_params, {}, postprocess_params

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
return super().forward(inputs, **forward_params)

def postprocess(self,
inputs: Dict[str, Any],
candidate_labels,
multi_label=False) -> Dict[str, Any]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, Any]: the prediction results
"""

logits = inputs['logits']
if multi_label or len(candidate_labels) == 1:
logits = logits[..., [self.contradiction_id, self.entailment_id]]
scores = softmax(logits, axis=-1)[..., 1]
else:
logits = logits[..., self.entailment_id]
scores = softmax(logits, axis=-1)

reversed_index = list(reversed(scores.argsort()))
result = {
'labels': [candidate_labels[i] for i in reversed_index],
'scores': [scores[i].item() for i in reversed_index],
}
return result

+ 7
- 0
modelscope/pipelines/outputs.py View File

@@ -101,6 +101,13 @@ TASK_OUTPUTS = {
# }
Tasks.sentence_similarity: ['scores', 'labels'],

# zero-shot classification result for single sample
# {
# "labels": ["happy", "sad", "calm", "angry"],
# "scores": [0.9, 0.1, 0.05, 0.05]
# }
Tasks.zero_shot_classification: ['scores', 'labels'],

# ============ audio tasks ===================

# audio processed for single file in PCM format


+ 45
- 1
modelscope/preprocessors/nlp.py View File

@@ -14,7 +14,7 @@ from .builder import PREPROCESSORS
__all__ = [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
'FillMaskPreprocessor'
'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor'
]


@@ -286,3 +286,47 @@ class TokenClassifcationPreprocessor(Preprocessor):
'attention_mask': attention_mask,
'token_type_ids': token_type_ids
}


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer)
class ZeroShotClassificationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

from sofa import SbertTokenizer
self.model_dir: str = model_dir
self.sequence_length = kwargs.pop('sequence_length', 512)
self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

@type_assert(object, str)
def __call__(self, data: str, hypothesis_template: str,
candidate_labels: list) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""
pairs = [[data, hypothesis_template.format(label)]
for label in candidate_labels]

features = self.tokenizer(
pairs,
padding=True,
truncation=True,
max_length=self.sequence_length,
return_tensors='pt',
truncation_strategy='only_first')
return features

+ 1
- 0
modelscope/utils/constant.py View File

@@ -48,6 +48,7 @@ class Tasks(object):
fill_mask = 'fill-mask'
summarization = 'summarization'
question_answering = 'question-answering'
zero_shot_classification = 'zero-shot-classification'

# audio tasks
auto_speech_recognition = 'auto-speech-recognition'


+ 64
- 0
tests/pipelines/test_zero_shot_classification.py View File

@@ -0,0 +1,64 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import SbertForZeroShotClassification
from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ZeroShotClassificationTest(unittest.TestCase):
model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
sentence = '全新突破 解放军运20版空中加油机曝光'
labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
template = '这篇文章的标题是{}'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_direct_file_download(self):
cache_path = snapshot_download(self.model_id)
tokenizer = ZeroShotClassificationPreprocessor(cache_path)
model = SbertForZeroShotClassification(cache_path, tokenizer=tokenizer)
pipeline1 = ZeroShotClassificationPipeline(
model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.zero_shot_classification,
model=model,
preprocessor=tokenizer)

print(
f'sentence: {self.sentence}\n'
f'pipeline1:{pipeline1(input=self.sentence,candidate_labels=self.labels)}'
)
print()
print(
f'sentence: {self.sentence}\n'
f'pipeline2: {pipeline2(self.sentence,candidate_labels=self.labels,hypothesis_template=self.template)}'
)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
tokenizer = ZeroShotClassificationPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.zero_shot_classification,
model=model,
preprocessor=tokenizer)
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.zero_shot_classification, model=self.model_id)
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.zero_shot_classification)
print(pipeline_ins(input=self.sentence, candidate_labels=self.labels))


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save