
[to #42322933] pre commit fix

Branch: master
Author: 易相, 3 years ago
Commit: 2c77fba805

7 changed files with 30 additions and 32 deletions:

  1. modelscope/models/nlp/__init__.py (+1, -1)
  2. modelscope/models/nlp/zero_shot_classification_model.py (+5, -3)
  3. modelscope/pipelines/builder.py (+2, -1)
  4. modelscope/pipelines/nlp/__init__.py (+1, -1)
  5. modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+6, -7)
  6. modelscope/preprocessors/nlp.py (+6, -8)
  7. tests/pipelines/test_zero_shot_classification.py (+9, -11)

modelscope/models/nlp/__init__.py (+1, -1)

@@ -1,4 +1,4 @@
 from .sentence_similarity_model import *  # noqa F403
 from .sequence_classification_model import *  # noqa F403
 from .text_generation_model import *  # noqa F403
-from .zero_shot_classification_model import *
+from .zero_shot_classification_model import *  # noqa F403

modelscope/models/nlp/zero_shot_classification_model.py (+5, -3)

@@ -1,6 +1,7 @@
 from typing import Any, Dict
-import torch
+
 import numpy as np
+import torch

 from modelscope.utils.constant import Tasks
 from ..base import Model
@@ -10,7 +11,8 @@ __all__ = ['BertForZeroShotClassification']


 @MODELS.register_module(
-    Tasks.zero_shot_classification, module_name=r'bert-zero-shot-classification')
+    Tasks.zero_shot_classification,
+    module_name=r'bert-zero-shot-classification')
 class BertForZeroShotClassification(Model):

     def __init__(self, model_dir: str, *args, **kwargs):
@@ -40,6 +42,6 @@ class BertForZeroShotClassification(Model):
         """
         with torch.no_grad():
             outputs = self.model(**input)
-            logits = outputs["logits"].numpy()
+            logits = outputs['logits'].numpy()
             res = {'logits': logits}
             return res

modelscope/pipelines/builder.py (+2, -1)

@@ -20,7 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
     Tasks.zero_shot_classification:
-    ('bert-zero-shot-classification', 'damo/nlp_structbert_zero-shot-classification_chinese-base'),
+    ('bert-zero-shot-classification',
+     'damo/nlp_structbert_zero-shot-classification_chinese-base'),
     Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
     Tasks.image_captioning: ('ofa', None),
     Tasks.image_generation:
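
With this entry in DEFAULT_MODEL_FOR_PIPELINE, the builder can resolve the task name alone to the registered bert-zero-shot-classification pipeline and its default checkpoint. The sketch below shows what that enables; it is based on the tests later in this commit, and the exact import path and call are an assumption rather than verbatim project code (candidate_labels still has to be supplied, since the pipeline pops it without a default).

# Hedged sketch: build the zero-shot pipeline from the task name only and let
# the builder fall back to the default model id registered above (assumed API).
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

classifier = pipeline(
    task=Tasks.zero_shot_classification,
    # required: the pipeline and preprocessor pop this kwarg without a default
    candidate_labels=['文化', '体育', '军事'])
print(classifier(input='全新突破 解放军运20版空中加油机曝光'))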


modelscope/pipelines/nlp/__init__.py (+1, -1)

@@ -1,4 +1,4 @@
 from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
-from .zero_shot_classification_pipeline import *
+from .zero_shot_classification_pipeline import *  # noqa F403

modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+6, -7)

@@ -4,6 +4,7 @@ from typing import Any, Dict, Union

 import json
 import numpy as np
+from scipy.special import softmax

 from modelscope.models.nlp import BertForZeroShotClassification
 from modelscope.preprocessors import ZeroShotClassificationPreprocessor
@@ -11,7 +12,6 @@ from modelscope.utils.constant import Tasks
 from ...models import Model
 from ..base import Input, Pipeline
 from ..builder import PIPELINES
-from scipy.special import softmax

 __all__ = ['ZeroShotClassificationPipeline']

@@ -39,16 +39,15 @@ class ZeroShotClassificationPipeline(Pipeline):

         self.entailment_id = 0
         self.contradiction_id = 2
-        self.candidate_labels = kwargs.pop("candidate_labels")
-        self.hypothesis_template = kwargs.pop('hypothesis_template', "{}")
+        self.candidate_labels = kwargs.pop('candidate_labels')
+        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
         self.multi_label = kwargs.pop('multi_label', False)

         if preprocessor is None:
             preprocessor = ZeroShotClassificationPreprocessor(
                 sc_model.model_dir,
                 candidate_labels=self.candidate_labels,
-                hypothesis_template=self.hypothesis_template
-            )
+                hypothesis_template=self.hypothesis_template)
         super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
@@ -72,7 +71,7 @@ class ZeroShotClassificationPipeline(Pipeline):

         reversed_index = list(reversed(scores.argsort()))
         result = {
-            "labels": [self.candidate_labels[i] for i in reversed_index],
-            "scores": [scores[i].item() for i in reversed_index],
+            'labels': [self.candidate_labels[i] for i in reversed_index],
+            'scores': [scores[i].item() for i in reversed_index],
         }
         return result
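
The hunks above only reformat quoting and show the softmax import, the entailment/contradiction ids and the final labels/scores dict, so the scoring itself stays off-screen. As a reading aid, here is a self-contained sketch of the NLI-based scoring such a postprocess typically performs; the multi_label branching follows the common zero-shot formulation and is an assumption, not the verbatim ModelScope code.

import numpy as np
from scipy.special import softmax

def zero_shot_scores(logits, multi_label=False, entailment_id=0, contradiction_id=2):
    # logits: shape (num_labels, 3), one NLI prediction per candidate label,
    # matching entailment_id = 0 and contradiction_id = 2 above.
    if multi_label:
        # Score each label independently via softmax over (contradiction, entailment).
        pair = logits[:, [contradiction_id, entailment_id]]
        return softmax(pair, axis=-1)[:, 1]
    # Single-label case: softmax of the entailment logits across all candidates.
    return softmax(logits[:, entailment_id], axis=-1)

logits = np.array([[3.1, 0.2, -2.0], [0.5, 0.1, 1.9]])  # two candidate labels
scores = zero_shot_scores(logits)
reversed_index = list(reversed(scores.argsort()))  # highest score first, as in postprocess()
print([scores[i].item() for i in reversed_index])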

modelscope/preprocessors/nlp.py (+6, -8)

@@ -12,8 +12,7 @@ from .builder import PREPROCESSORS

 __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
-    'TextGenerationPreprocessor',
-    "ZeroShotClassificationPreprocessor"
+    'TextGenerationPreprocessor', 'ZeroShotClassificationPreprocessor'
 ]


@@ -190,8 +189,8 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
         from sofa import SbertTokenizer
         self.model_dir: str = model_dir
         self.sequence_length = kwargs.pop('sequence_length', 512)
-        self.candidate_labels = kwargs.pop("candidate_labels")
-        self.hypothesis_template = kwargs.pop('hypothesis_template', "{}")
+        self.candidate_labels = kwargs.pop('candidate_labels')
+        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
         self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

     @type_assert(object, str)
@@ -206,7 +205,8 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        pairs = [[data, self.hypothesis_template.format(label)] for label in self.candidate_labels]
+        pairs = [[data, self.hypothesis_template.format(label)]
+                 for label in self.candidate_labels]

         features = self.tokenizer(
             pairs,
@@ -214,7 +214,5 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
             truncation=True,
             max_length=self.sequence_length,
             return_tensors='pt',
-            truncation_strategy='only_first'
-        )
+            truncation_strategy='only_first')
         return features
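
The reformatted pairs expression is the core of this preprocessor: every candidate label is turned into a hypothesis via hypothesis_template and paired with the input premise, so classifying against N labels becomes N NLI encoder inputs. A small illustration follows; the Hugging Face BertTokenizer stands in for sofa's SbertTokenizer, whose API is not shown in this diff, so treat the tokenizer call as an assumption.

from transformers import BertTokenizer  # stand-in for SbertTokenizer (assumption)

data = '全新突破 解放军运20版空中加油机曝光'
candidate_labels = ['军事', '体育', '文化']
hypothesis_template = '{}'

# Same construction as in preprocess(): one (premise, hypothesis) pair per label.
pairs = [[data, hypothesis_template.format(label)]
         for label in candidate_labels]

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
features = tokenizer(
    pairs,
    padding=True,
    truncation='only_first',  # truncate the premise if needed, keep the hypothesis
    max_length=512,
    return_tensors='pt')
print(features['input_ids'].shape)  # (num_candidate_labels, sequence_length)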


tests/pipelines/test_zero_shot_classification.py (+9, -11)

@@ -13,13 +13,13 @@ from modelscope.utils.constant import Tasks
 class ZeroShotClassificationTest(unittest.TestCase):
     model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
     sentence = '全新突破 解放军运20版空中加油机曝光'
-    candidate_labels = ["文化", "体育", "娱乐", "财经", "家居", "汽车", "教育", "科技", "军事"]
+    candidate_labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']

     def test_run_from_local(self):
         cache_path = snapshot_download(self.model_id)
-        tokenizer = ZeroShotClassificationPreprocessor(cache_path, candidate_labels=self.candidate_labels)
-        model = BertForZeroShotClassification(
-            cache_path, tokenizer=tokenizer)
+        tokenizer = ZeroShotClassificationPreprocessor(
+            cache_path, candidate_labels=self.candidate_labels)
+        model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
         pipeline1 = ZeroShotClassificationPipeline(
             model,
             preprocessor=tokenizer,
@@ -29,8 +29,7 @@ class ZeroShotClassificationTest(unittest.TestCase):
             Tasks.zero_shot_classification,
             model=model,
             preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels
-        )
+            candidate_labels=self.candidate_labels)

         print(f'sentence: {self.sentence}\n'
               f'pipeline1:{pipeline1(input=self.sentence)}')
@@ -40,21 +39,20 @@ class ZeroShotClassificationTest(unittest.TestCase):

     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
-        tokenizer = ZeroShotClassificationPreprocessor(model.model_dir, candidate_labels=self.candidate_labels)
+        tokenizer = ZeroShotClassificationPreprocessor(
+            model.model_dir, candidate_labels=self.candidate_labels)
         pipeline_ins = pipeline(
             task=Tasks.zero_shot_classification,
             model=model,
             preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels
-        )
+            candidate_labels=self.candidate_labels)
         print(pipeline_ins(input=self.sentence))

     def test_run_with_model_name(self):
         pipeline_ins = pipeline(
             task=Tasks.zero_shot_classification,
             model=self.model_id,
-            candidate_labels=self.candidate_labels
-        )
+            candidate_labels=self.candidate_labels)
         print(pipeline_ins(input=self.sentence))

     def test_run_with_default_model(self):

