init

Branch: master
Author: 易相, 3 years ago
Commit: 34db19131f
9 changed files with 243 additions and 1 deletion
  1. +1  -0  modelscope/models/nlp/__init__.py
  2. +45 -0  modelscope/models/nlp/zero_shot_classification_model.py
  3. +2  -0  modelscope/pipelines/builder.py
  4. +1  -0  modelscope/pipelines/nlp/__init__.py
  5. +78 -0  modelscope/pipelines/nlp/zero_shot_classification_pipeline.py
  6. +1  -1  modelscope/preprocessors/__init__.py
  7. +46 -0  modelscope/preprocessors/nlp.py
  8. +1  -0  modelscope/utils/constant.py
  9. +68 -0  tests/pipelines/test_zero_shot_classification.py

+1 -0  modelscope/models/nlp/__init__.py

@@ -1,2 +1,3 @@
from .sequence_classification_model import *  # noqa F403
from .text_generation_model import *  # noqa F403
from .zero_shot_classification_model import *  # noqa F403

+45 -0  modelscope/models/nlp/zero_shot_classification_model.py

@@ -0,0 +1,45 @@
from typing import Any, Dict

import numpy as np
import torch

from modelscope.utils.constant import Tasks
from ..base import Model
from ..builder import MODELS

__all__ = ['BertForZeroShotClassification']


@MODELS.register_module(
    Tasks.zero_shot_classification,
    module_name=r'bert-zero-shot-classification')
class BertForZeroShotClassification(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the zero shot classification model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        from sofa import SbertForSequenceClassification
        self.model = SbertForSequenceClassification.from_pretrained(model_dir)
        self.model.eval()

    def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
        """return the result by the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
                Example:
                    {
                        'logits': array([[-0.53860897, 1.5029076]], dtype=float32)  # true value
                    }
        """
        with torch.no_grad():
            outputs = self.model(**input)
        logits = outputs['logits'].numpy()
        res = {'logits': logits}
        return res
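
For orientation (not part of the commit), a minimal sketch of exercising this model on its own, assuming a locally downloaded checkpoint directory; the path is a placeholder, and the sentence and labels are borrowed from the test file below:

from modelscope.models.nlp import BertForZeroShotClassification
from modelscope.preprocessors import ZeroShotClassificationPreprocessor

model_dir = '/path/to/nlp_structbert_zero-shot-classification_chinese-base'  # placeholder path
preprocessor = ZeroShotClassificationPreprocessor(
    model_dir, candidate_labels=['体育', '军事'])
model = BertForZeroShotClassification(model_dir)

features = preprocessor('全新突破 解放军运20版空中加油机曝光')  # input_ids / attention_mask / token_type_ids
outputs = model.forward(features)  # dict with a numpy 'logits' array
print(outputs['logits'].shape)     # one row of NLI logits per candidate label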

+2 -0  modelscope/pipelines/builder.py

@@ -20,6 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.image_matting: ('image-matting', 'damo/image-matting-person'),
    Tasks.text_classification:
    ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
    Tasks.zero_shot_classification:
    ('bert-zero-shot-classification',
     'damo/nlp_structbert_zero-shot-classification_chinese-base'),
    Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
    Tasks.image_captioning: ('ofa', None),
    Tasks.image_generation:


+1 -0  modelscope/pipelines/nlp/__init__.py

@@ -1,2 +1,3 @@
from .sequence_classification_pipeline import *  # noqa F403
from .text_generation_pipeline import *  # noqa F403
from .zero_shot_classification_pipeline import *  # noqa F403

+78 -0  modelscope/pipelines/nlp/zero_shot_classification_pipeline.py

@@ -0,0 +1,78 @@
import os
import uuid
from typing import Any, Dict, Union

import json
import numpy as np
from scipy.special import softmax

from modelscope.models.nlp import BertForZeroShotClassification
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks
from ...models import Model
from ..base import Input, Pipeline
from ..builder import PIPELINES

__all__ = ['ZeroShotClassificationPipeline']


@PIPELINES.register_module(
    Tasks.zero_shot_classification,
    module_name=r'bert-zero-shot-classification')
class ZeroShotClassificationPipeline(Pipeline):

    def __init__(self,
                 model: Union[BertForZeroShotClassification, str],
                 preprocessor: ZeroShotClassificationPreprocessor = None,
                 **kwargs):
        """use `model` and `preprocessor` to create a zero-shot classification pipeline for prediction

        Args:
            model (BertForZeroShotClassification or str): a model instance or a model id/path
            preprocessor (ZeroShotClassificationPreprocessor): a preprocessor instance
        """
        assert isinstance(model, str) or isinstance(model, BertForZeroShotClassification), \
            'model must be a single str or BertForZeroShotClassification'
        sc_model = model if isinstance(
            model,
            BertForZeroShotClassification) else Model.from_pretrained(model)

        self.entailment_id = 0
        self.contradiction_id = 2
        self.candidate_labels = kwargs.pop('candidate_labels')
        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
        self.multi_label = kwargs.pop('multi_label', False)

        if preprocessor is None:
            preprocessor = ZeroShotClassificationPreprocessor(
                sc_model.model_dir,
                candidate_labels=self.candidate_labels,
                hypothesis_template=self.hypothesis_template)
        super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the model outputs, a dict containing the NLI logits

        Returns:
            Dict[str, Any]: the prediction results, candidate labels sorted by score
        """
        logits = inputs['logits']

        if self.multi_label or len(self.candidate_labels) == 1:
            # score each label independently from its (contradiction, entailment) pair
            logits = logits[..., [self.contradiction_id, self.entailment_id]]
            scores = softmax(logits, axis=-1)[..., 1]
        else:
            # single-label: softmax the entailment logits across all candidate labels
            logits = logits[..., self.entailment_id]
            scores = softmax(logits, axis=-1)

        reversed_index = list(reversed(scores.argsort()))
        result = {
            'labels': [self.candidate_labels[i] for i in reversed_index],
            'scores': [scores[i].item() for i in reversed_index],
        }
        return result
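
As a worked illustration of the default single-label branch of `postprocess` above (the numbers are invented, not taken from the model), the entailment logit at index 0 is collected for every (premise, hypothesis) pair and then normalized across the candidate labels:

import numpy as np
from scipy.special import softmax

candidate_labels = ['体育', '军事']      # hypothetical two-label example
logits = np.array([[-0.8, 0.1, 1.5],     # NLI logits for (sentence, '体育')
                   [1.2, -0.3, -0.9]])   # NLI logits for (sentence, '军事')

scores = softmax(logits[..., 0], axis=-1)  # entailment logit per label, softmaxed across labels
order = list(reversed(scores.argsort()))
print([candidate_labels[i] for i in order])          # ['军事', '体育']
print([round(scores[i].item(), 3) for i in order])   # roughly [0.881, 0.119]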

+1 -1  modelscope/preprocessors/__init__.py

@@ -5,4 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
from .common import Compose
from .image import LoadImage, load_image
from .nlp import *  # noqa F403
-from .nlp import TextGenerationPreprocessor
+from .nlp import TextGenerationPreprocessor, ZeroShotClassificationPreprocessor

+46 -0  modelscope/preprocessors/nlp.py

@@ -147,3 +147,49 @@ class TextGenerationPreprocessor(Preprocessor):
            rst['token_type_ids'].append(feature['token_type_ids'])

        return {k: torch.tensor(v) for k, v in rst.items()}


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-zero-shot-classification')
class ZeroShotClassificationPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """preprocess the data via the vocab.txt from the `model_dir` path

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        from sofa import SbertTokenizer
        self.model_dir: str = model_dir
        self.sequence_length = kwargs.pop('sequence_length', 512)
        self.candidate_labels = kwargs.pop('candidate_labels')
        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """process the raw input data

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        pairs = [[data, self.hypothesis_template.format(label)]
                 for label in self.candidate_labels]

        features = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            max_length=self.sequence_length,
            return_tensors='pt',
            truncation_strategy='only_first')
        return features
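
To make the pairing step concrete (illustration only, not part of the commit), the preprocessor turns one sentence plus the candidate labels into (premise, hypothesis) pairs before tokenization; with the default hypothesis_template of '{}' the hypothesis is simply the label text:

sentence = '全新突破 解放军运20版空中加油机曝光'
candidate_labels = ['体育', '军事', '科技']   # a subset of the labels used in the tests
hypothesis_template = '{}'                    # default template in this commit

pairs = [[sentence, hypothesis_template.format(label)]
         for label in candidate_labels]
# Each pair is then tokenized as a (premise, hypothesis) input for the NLI head.
print(pairs)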


+1 -0  modelscope/utils/constant.py

@@ -30,6 +30,7 @@ class Tasks(object):
    image_matting = 'image-matting'

    # nlp tasks
    zero_shot_classification = 'zero-shot-classification'
    sentiment_analysis = 'sentiment-analysis'
    text_classification = 'text-classification'
    relation_extraction = 'relation-extraction'


+68 -0  tests/pipelines/test_zero_shot_classification.py

@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from maas_hub.snapshot_download import snapshot_download

from modelscope.models import Model
from modelscope.models.nlp import BertForZeroShotClassification
from modelscope.pipelines import ZeroShotClassificationPipeline, pipeline
from modelscope.preprocessors import ZeroShotClassificationPreprocessor
from modelscope.utils.constant import Tasks


class ZeroShotClassificationTest(unittest.TestCase):
    model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
    sentence = '全新突破 解放军运20版空中加油机曝光'
    candidate_labels = [
        '文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'
    ]

    def test_run_from_local(self):
        cache_path = snapshot_download(self.model_id)
        tokenizer = ZeroShotClassificationPreprocessor(
            cache_path, candidate_labels=self.candidate_labels)
        model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
        pipeline1 = ZeroShotClassificationPipeline(
            model,
            preprocessor=tokenizer,
            candidate_labels=self.candidate_labels)
        pipeline2 = pipeline(
            Tasks.zero_shot_classification,
            model=model,
            preprocessor=tokenizer,
            candidate_labels=self.candidate_labels)

        print(f'sentence: {self.sentence}\n'
              f'pipeline1: {pipeline1(input=self.sentence)}')
        print()
        print(f'sentence: {self.sentence}\n'
              f'pipeline2: {pipeline2(input=self.sentence)}')

    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        tokenizer = ZeroShotClassificationPreprocessor(
            model.model_dir, candidate_labels=self.candidate_labels)
        pipeline_ins = pipeline(
            task=Tasks.zero_shot_classification,
            model=model,
            preprocessor=tokenizer,
            candidate_labels=self.candidate_labels)
        print(pipeline_ins(input=self.sentence))

    def test_run_with_model_name(self):
        pipeline_ins = pipeline(
            task=Tasks.zero_shot_classification,
            model=self.model_id,
            candidate_labels=self.candidate_labels)
        print(pipeline_ins(input=self.sentence))

    def test_run_with_default_model(self):
        pipeline_ins = pipeline(
            task=Tasks.zero_shot_classification,
            candidate_labels=self.candidate_labels)
        print(pipeline_ins(input=self.sentence))


if __name__ == '__main__':
    unittest.main()
