
[to #42322933] pre commit fix

master · 易相 · 3 years ago · commit 2c77fba805
7 changed files with 30 additions and 32 deletions
  1. modelscope/models/nlp/__init__.py (+1, -1)
  2. modelscope/models/nlp/zero_shot_classification_model.py (+5, -3)
  3. modelscope/pipelines/builder.py (+2, -1)
  4. modelscope/pipelines/nlp/__init__.py (+1, -1)
  5. modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+6, -7)
  6. modelscope/preprocessors/nlp.py (+6, -8)
  7. tests/pipelines/test_zero_shot_classification.py (+9, -11)

modelscope/models/nlp/__init__.py (+1, -1)

@@ -1,4 +1,4 @@
 from .sentence_similarity_model import *  # noqa F403
 from .sequence_classification_model import *  # noqa F403
 from .text_generation_model import *  # noqa F403
-from .zero_shot_classification_model import *
+from .zero_shot_classification_model import *  # noqa F403

modelscope/models/nlp/zero_shot_classification_model.py (+5, -3)

@@ -1,6 +1,7 @@
from typing import Any, Dict from typing import Any, Dict
import torch
import numpy as np import numpy as np
import torch


from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks
from ..base import Model from ..base import Model
@@ -10,7 +11,8 @@ __all__ = ['BertForZeroShotClassification']




@MODELS.register_module( @MODELS.register_module(
Tasks.zero_shot_classification, module_name=r'bert-zero-shot-classification')
Tasks.zero_shot_classification,
module_name=r'bert-zero-shot-classification')
class BertForZeroShotClassification(Model): class BertForZeroShotClassification(Model):


def __init__(self, model_dir: str, *args, **kwargs): def __init__(self, model_dir: str, *args, **kwargs):
@@ -40,6 +42,6 @@ class BertForZeroShotClassification(Model):
""" """
with torch.no_grad(): with torch.no_grad():
outputs = self.model(**input) outputs = self.model(**input)
logits = outputs["logits"].numpy()
logits = outputs['logits'].numpy()
res = {'logits': logits} res = {'logits': logits}
return res return res

modelscope/pipelines/builder.py (+2, -1)

@@ -20,7 +20,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.text_classification:
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
     Tasks.zero_shot_classification:
-    ('bert-zero-shot-classification', 'damo/nlp_structbert_zero-shot-classification_chinese-base'),
+    ('bert-zero-shot-classification',
+     'damo/nlp_structbert_zero-shot-classification_chinese-base'),
     Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
     Tasks.image_captioning: ('ofa', None),
     Tasks.image_generation:
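
With the default mapping above, the task name alone resolves to the damo/nlp_structbert_zero-shot-classification_chinese-base model. A minimal usage sketch (illustration only, not part of this commit; it mirrors the updated tests, and candidate_labels still has to be supplied by the caller):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # The builder falls back to DEFAULT_MODEL_FOR_PIPELINE when no model is given.
    pipe = pipeline(
        task=Tasks.zero_shot_classification,
        candidate_labels=['文化', '体育', '军事'])
    print(pipe(input='全新突破 解放军运20版空中加油机曝光'))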


modelscope/pipelines/nlp/__init__.py (+1, -1)

@@ -1,4 +1,4 @@
 from .sentence_similarity_pipeline import *  # noqa F403
 from .sequence_classification_pipeline import *  # noqa F403
 from .text_generation_pipeline import *  # noqa F403
-from .zero_shot_classification_pipeline import *
+from .zero_shot_classification_pipeline import *  # noqa F403

modelscope/pipelines/nlp/zero_shot_classification_pipeline.py (+6, -7)

@@ -4,6 +4,7 @@ from typing import Any, Dict, Union
 
 import json
 import numpy as np
+from scipy.special import softmax
 
 from modelscope.models.nlp import BertForZeroShotClassification
 from modelscope.preprocessors import ZeroShotClassificationPreprocessor
@@ -11,7 +12,6 @@ from modelscope.utils.constant import Tasks
 from ...models import Model
 from ..base import Input, Pipeline
 from ..builder import PIPELINES
-from scipy.special import softmax
 
 
 __all__ = ['ZeroShotClassificationPipeline']
@@ -39,16 +39,15 @@ class ZeroShotClassificationPipeline(Pipeline):
 
         self.entailment_id = 0
         self.contradiction_id = 2
-        self.candidate_labels = kwargs.pop("candidate_labels")
-        self.hypothesis_template = kwargs.pop('hypothesis_template', "{}")
+        self.candidate_labels = kwargs.pop('candidate_labels')
+        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
         self.multi_label = kwargs.pop('multi_label', False)
 
         if preprocessor is None:
             preprocessor = ZeroShotClassificationPreprocessor(
                 sc_model.model_dir,
                 candidate_labels=self.candidate_labels,
-                hypothesis_template=self.hypothesis_template
-            )
+                hypothesis_template=self.hypothesis_template)
         super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
@@ -72,7 +71,7 @@ class ZeroShotClassificationPipeline(Pipeline):
 
         reversed_index = list(reversed(scores.argsort()))
         result = {
-            "labels": [self.candidate_labels[i] for i in reversed_index],
-            "scores": [scores[i].item() for i in reversed_index],
+            'labels': [self.candidate_labels[i] for i in reversed_index],
+            'scores': [scores[i].item() for i in reversed_index],
         }
         return result
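
The body of postprocess between the two hunks above is not part of this diff; the visible pieces (the scipy softmax import, entailment_id = 0, and the argsort over scores) point at the usual NLI-style scoring. A hedged sketch of that idea for the single-label case, with the logits shape assumed rather than taken from the diff (the multi_label branch is omitted):

    import numpy as np
    from scipy.special import softmax

    logits = np.random.randn(9, 3)  # one (entailment, neutral, contradiction) row per label
    entailment_id = 0
    # Single-label case: softmax the entailment logits across candidate labels.
    scores = softmax(logits[:, entailment_id], axis=-1)
    reversed_index = list(reversed(scores.argsort()))  # best-scoring label first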

modelscope/preprocessors/nlp.py (+6, -8)

@@ -12,8 +12,7 @@ from .builder import PREPROCESSORS
 
 __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
-    'TextGenerationPreprocessor',
-    "ZeroShotClassificationPreprocessor"
+    'TextGenerationPreprocessor', 'ZeroShotClassificationPreprocessor'
 ]
 
 
@@ -190,8 +189,8 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
         from sofa import SbertTokenizer
         self.model_dir: str = model_dir
         self.sequence_length = kwargs.pop('sequence_length', 512)
-        self.candidate_labels = kwargs.pop("candidate_labels")
-        self.hypothesis_template = kwargs.pop('hypothesis_template', "{}")
+        self.candidate_labels = kwargs.pop('candidate_labels')
+        self.hypothesis_template = kwargs.pop('hypothesis_template', '{}')
         self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
 
     @type_assert(object, str)
@@ -206,7 +205,8 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        pairs = [[data, self.hypothesis_template.format(label)] for label in self.candidate_labels]
+        pairs = [[data, self.hypothesis_template.format(label)]
+                 for label in self.candidate_labels]
 
         features = self.tokenizer(
             pairs,
@@ -214,7 +214,5 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
             truncation=True,
             max_length=self.sequence_length,
             return_tensors='pt',
-            truncation_strategy='only_first'
-        )
+            truncation_strategy='only_first')
         return features
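
For reference, the pair construction being re-wrapped above turns one input sentence plus the candidate labels into premise/hypothesis pairs before tokenization; with the default hypothesis_template '{}' each label is inserted verbatim. A small worked example (values chosen for illustration):

    data = '全新突破 解放军运20版空中加油机曝光'
    candidate_labels = ['体育', '军事']
    hypothesis_template = '{}'
    pairs = [[data, hypothesis_template.format(label)]
             for label in candidate_labels]
    # pairs == [['全新突破 解放军运20版空中加油机曝光', '体育'],
    #           ['全新突破 解放军运20版空中加油机曝光', '军事']]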


tests/pipelines/test_zero_shot_classification.py (+9, -11)

@@ -13,13 +13,13 @@ from modelscope.utils.constant import Tasks
 class ZeroShotClassificationTest(unittest.TestCase):
     model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
     sentence = '全新突破 解放军运20版空中加油机曝光'
-    candidate_labels = ["文化", "体育", "娱乐", "财经", "家居", "汽车", "教育", "科技", "军事"]
+    candidate_labels = ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
 
     def test_run_from_local(self):
         cache_path = snapshot_download(self.model_id)
-        tokenizer = ZeroShotClassificationPreprocessor(cache_path, candidate_labels=self.candidate_labels)
-        model = BertForZeroShotClassification(
-            cache_path, tokenizer=tokenizer)
+        tokenizer = ZeroShotClassificationPreprocessor(
+            cache_path, candidate_labels=self.candidate_labels)
+        model = BertForZeroShotClassification(cache_path, tokenizer=tokenizer)
         pipeline1 = ZeroShotClassificationPipeline(
             model,
             preprocessor=tokenizer,
@@ -29,8 +29,7 @@ class ZeroShotClassificationTest(unittest.TestCase):
             Tasks.zero_shot_classification,
             model=model,
             preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels
-        )
+            candidate_labels=self.candidate_labels)
 
         print(f'sentence: {self.sentence}\n'
               f'pipeline1:{pipeline1(input=self.sentence)}')
@@ -40,21 +39,20 @@ class ZeroShotClassificationTest(unittest.TestCase):
 
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
-        tokenizer = ZeroShotClassificationPreprocessor(model.model_dir, candidate_labels=self.candidate_labels)
+        tokenizer = ZeroShotClassificationPreprocessor(
+            model.model_dir, candidate_labels=self.candidate_labels)
         pipeline_ins = pipeline(
             task=Tasks.zero_shot_classification,
             model=model,
             preprocessor=tokenizer,
-            candidate_labels=self.candidate_labels
-        )
+            candidate_labels=self.candidate_labels)
         print(pipeline_ins(input=self.sentence))
 
     def test_run_with_model_name(self):
         pipeline_ins = pipeline(
             task=Tasks.zero_shot_classification,
             model=self.model_id,
-            candidate_labels=self.candidate_labels
-        )
+            candidate_labels=self.candidate_labels)
         print(pipeline_ins(input=self.sentence))
 
     def test_run_with_default_model(self):

