
[to #42322933] nlp preprocessor refactor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9269314

* init

* token to ids

* add model

* model forward ready

* add intent

* intent preprocessor ready

* intent success

* merge master

* test with model hub

* add flake8

* update

* update

* update

* Merge branch 'master' into nlp/space/gen

* delete file about gen

* init

* fix flake8 bug

* [to #42322933] init

* bug fix

* [to #42322933] init

* update pipeline registry info

* Merge remote-tracking branch 'origin/master' into feat/nli

* [to #42322933] init

* [to #42322933] init

* modify forward

* [to #42322933] init

* generation ready

* init

* Merge branch 'master' into feat/zero_shot_classification

# Conflicts:
#	modelscope/preprocessors/__init__.py

* [to #42322933] bugfix

* [to #42322933] pre commit fix

* fill mask

* register multiple models for model and pipeline

* add tests

* test level >= 0

* local gen ready

* merge with master

* dialog modeling ready

* fix comments: rename and refactor AliceMindMLM; adjust pipeline

* space intent and modeling(generation) are ready

* bug fix

* add dep

* add dep

* support dst data processor

* merge with nlp/space/dst

* merge with master

* Merge remote-tracking branch 'origin' into feat/fill_mask

Conflicts:
	modelscope/models/nlp/__init__.py
	modelscope/pipelines/builder.py
	modelscope/pipelines/outputs.py
	modelscope/preprocessors/nlp.py
	requirements/nlp.txt

* merge with master

* merge with master 2/2

* fix comments

* fix isort for pre-commit check

* allow passing params to pipeline's __call__ method

* Merge remote-tracking branch 'origin/master' into feat/zero_shot_classification

* merge with nli task

* merge with sentiment_classification

* merge with zero_shot_classification

* merge with fill_mask

* merge with space

* merge with master head

* Merge remote-tracking branch 'origin' into feat/fill_mask

Conflicts:
	modelscope/utils/constant.py

* fix: change pipeline module_name from model_type to 'fill_mask' & fix merge bug

* unfinished change

* fix bug

* unfinished

* unfinished

* revise modelhub dependency

* Merge branch 'feat/nlp_refactor' of http://gitlab.alibaba-inc.com/Ali-MaaS/MaaS-lib into feat/nlp_refactor

* add eval() to pipeline call

* add test level

* ut run passed

* add default args

* tmp

* merge master

* all ut passed

* remove a useless enum

* revert a mistaken modification

* revert a mistaken modification

* Merge commit 'ace8af92465f7d772f035aebe98967726655f12c' into feat/nlp

* commit 'ace8af92465f7d772f035aebe98967726655f12c':
  [to #42322933] Add cv-action-recongnition-pipeline to maas lib
  [to #42463204]  support Pil.Image for image_captioning_pipeline
  [to #42670107] restore pydataset test
  [to #42322933] add create if not exist and add(back) create model example         Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9130661
  [to #41474818]fix: fix errors in task name definition

# Conflicts:
#	modelscope/pipelines/builder.py
#	modelscope/utils/constant.py

* Merge branch 'feat/nlp' into feat/nlp_refactor

* feat/nlp:
  [to #42322933] Add cv-action-recongnition-pipeline to maas lib
  [to #42463204]  support Pil.Image for image_captioning_pipeline
  [to #42670107] restore pydataset test
  [to #42322933] add create if not exist and add(back) create model example         Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9130661
  [to #41474818]fix: fix errors in task name definition

# Conflicts:
#	modelscope/pipelines/builder.py

* fix compile bug

* refactor space

* Merge branch 'feat/nlp_refactor' of http://gitlab.alibaba-inc.com/Ali-MaaS/MaaS-lib into feat/nlp_refactor

* Merge remote-tracking branch 'origin' into feat/fill_mask

* fix

* pre-commit lint

* lint file

* lint file

* lint file

* update modelhub dependency

* lint file

* ignore dst_processor temporarily

* address review comments: 1. rename MaskedLMModelBase to MaskedLanguageModelBase 2. remove a useless import

* recommit

* remove MaskedLanguageModel from __all__

* Merge commit '1a0d4af55a2eee69d89633874890f50eda8f8700' into feat/nlp_refactor

* commit '1a0d4af55a2eee69d89633874890f50eda8f8700':
  [to #42322933] test level check         Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9143809
  [to #42322933] update nlp models name in metainfo         Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9134657

# Conflicts:
#	modelscope/metainfo.py

* update

* revert pipeline params update

* remove zeroshot

* update sequence classification outputs

* merge with fill mask

* Merge remote-tracking branch 'origin' into feat/fill_mask

* fix

* fix flake8 warning of dst

* Merge remote-tracking branch 'origin/feat/fill_mask' into feat/nlp

* merge with master

* remove useless test.py

* Merge remote-tracking branch 'origin/master' into feat/nlp

* remove unformatted space trainer

* revise based on review comments, except the Chinese comments

* skip ci blocking

* translation pipeline

* csanmt model for translation pipeline

* update

* update

* update builder.py

* translate Chinese notes in space3.0 to English

* translate Chinese comments to English

* add space to metainfo

* update csanmt_translation

* update csanmt transformer

* merge with master

* update csanmt translation

* update lint

* update metainfo.py

* Update translation_pipeline.py

* Update builder.py

* fix: 1. make csanmt derived from Model 2. add kwargs to prevent from call error

* pre-commit check

* temporarily exclude flake8

* temporarily ignore translation files

* fix bug

* pre-commit passed

* fix bug

* fix bug

* revert pre-commit ignore rules

* pre-commit passed

* fix bug

* merge with master

* add missing setting

* merge with master

* add outputs

* modify test level

* modify Chinese comment

* remove useless doc

* space outputs normalization

* Merge remote-tracking branch 'origin/master' into nlp/translation

* update translation_pipeline.py

* Merge remote-tracking branch 'origin/master' into feat/nlp

* Merge remote-tracking branch 'origin/master' into nlp/translation

* add new __init__ method

* add new __init__ method

* update output format

* Merge remote-tracking branch 'origin/master' into feat/nlp

* update output format

* merge with master

* merge with nlp/translate

* update the translation comment

* update the translation comment

* Merge branch 'nlp/translation' into feat/nlp

* Merge remote-tracking branch 'origin/master' into feat/nlp

* Merge remote-tracking branch 'origin/master' into feat/nlp

* nlp preprocessor refactor

* add get_model_type in utils.hub

* update the default preprocessor args

* update the fill mask preprocessor

* fix typo bug

Branch: master · zhangzhicheng.zzc, 3 years ago · commit cf194ef6cd
9 changed files with 130 additions and 267 deletions:

1. docs/source/tutorials/pipeline.md (+2 −2)
2. modelscope/metainfo.py (+1 −0)
3. modelscope/pipelines/nlp/sentence_similarity_pipeline.py (+4 −4)
4. modelscope/pipelines/nlp/text_generation_pipeline.py (+3 −3)
5. modelscope/pipelines/nlp/word_segmentation_pipeline.py (+8 −7)
6. modelscope/preprocessors/nlp.py (+88 −245)
7. modelscope/utils/hub.py (+18 −0)
8. tests/pipelines/test_sentence_similarity.py (+3 −3)
9. tests/pipelines/test_word_segmentation.py (+3 −3)

docs/source/tutorials/pipeline.md (+2 −2)

@@ -37,9 +37,9 @@ The pipeline function supports passing in instantiated preprocessor and model objects, thereby…
 1. First, create the preprocessor and the model
 ```python
 from modelscope.models import Model
-from modelscope.preprocessors import TokenClassifcationPreprocessor
+from modelscope.preprocessors import TokenClassificationPreprocessor
 model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base')
-tokenizer = TokenClassifcationPreprocessor(model.model_dir)
+tokenizer = TokenClassificationPreprocessor(model.model_dir)
 ```
 
 2. Create the pipeline from the tokenizer and model objects
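Taken together with the test changes below, the renamed preprocessor slots into the documented flow unchanged. A minimal end-to-end sketch assembled from this doc snippet and the updated word-segmentation test (the sample sentence is illustrative; the tests use their own `self.sentence`):

```python
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import TokenClassificationPreprocessor
from modelscope.utils.constant import Tasks

# build the model and the (renamed) preprocessor explicitly
model = Model.from_pretrained(
    'damo/nlp_structbert_word-segmentation_chinese-base')
tokenizer = TokenClassificationPreprocessor(model.model_dir)

# hand both to the pipeline and run it on a sample sentence
pipe = pipeline(
    task=Tasks.word_segmentation, model=model, preprocessor=tokenizer)
print(pipe(input='今天天气不错,适合出去游玩'))
```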


modelscope/metainfo.py (+1 −0)

@@ -106,6 +106,7 @@ class Preprocessors(object):
     load_image = 'load-image'
 
     # nlp preprocessor
+    sen_sim_tokenizer = 'sen-sim-tokenizer'
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
     token_cls_tokenizer = 'token-cls-tokenizer'
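The new `sen_sim_tokenizer` key is what the preprocessor registry looks up. For context, a condensed sketch of how this constant is consumed by the decorator-based registration in `modelscope/preprocessors/nlp.py` (class body elided; see the nlp.py diff below for the real implementation):

```python
from modelscope.metainfo import Preprocessors
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields


# registers the class under the 'sen-sim-tokenizer' key in the nlp field
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer)
class SentenceSimilarityPreprocessor(Preprocessor):
    ...
```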


modelscope/pipelines/nlp/sentence_similarity_pipeline.py (+4 −4)

@@ -6,7 +6,7 @@ import torch
 from ...metainfo import Pipelines
 from ...models import Model
 from ...models.nlp import SbertForSentenceSimilarity
-from ...preprocessors import SequenceClassificationPreprocessor
+from ...preprocessors import SentenceSimilarityPreprocessor
 from ...utils.constant import Tasks
 from ..base import Input, Pipeline
 from ..builder import PIPELINES
@@ -21,7 +21,7 @@ class SentenceSimilarityPipeline(Pipeline):
 
     def __init__(self,
                  model: Union[Model, str],
-                 preprocessor: SequenceClassificationPreprocessor = None,
+                 preprocessor: SentenceSimilarityPreprocessor = None,
                  first_sequence='first_sequence',
                  second_sequence='second_sequence',
                  **kwargs):
@@ -29,7 +29,7 @@ class SentenceSimilarityPipeline(Pipeline):
 
         Args:
             model (SbertForSentenceSimilarity): a model instance
-            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
+            preprocessor (SentenceSimilarityPreprocessor): a preprocessor instance
         """
         assert isinstance(model, str) or isinstance(model, SbertForSentenceSimilarity), \
             'model must be a single str or SbertForSentenceSimilarity'
@@ -37,7 +37,7 @@ class SentenceSimilarityPipeline(Pipeline):
             model,
             SbertForSentenceSimilarity) else Model.from_pretrained(model)
         if preprocessor is None:
-            preprocessor = SequenceClassificationPreprocessor(
+            preprocessor = SentenceSimilarityPreprocessor(
                 sc_model.model_dir,
                 first_sequence=first_sequence,
                 second_sequence=second_sequence)
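Usage after the swap, sketched from the updated test file (the model id is illustrative, the tests read theirs from `self.model_id`; the tuple input form follows the new base preprocessor's `__call__` in the nlp.py diff):

```python
from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import SentenceSimilarityPreprocessor
from modelscope.utils.constant import Tasks

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')  # illustrative id
tokenizer = SentenceSimilarityPreprocessor(model.model_dir)
pipeline_ins = pipeline(
    task=Tasks.sentence_similarity, model=model, preprocessor=tokenizer)
# a (sentence1, sentence2) pair is one of the accepted input shapes
print(pipeline_ins(input=('今天天气不错', '今天天气很好')))
```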


modelscope/pipelines/nlp/text_generation_pipeline.py (+3 −3)

@@ -22,11 +22,11 @@ class TextGenerationPipeline(Pipeline):
                  model: Union[PalmForTextGeneration, str],
                  preprocessor: Optional[TextGenerationPreprocessor] = None,
                  **kwargs):
-        """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
+        """use `model` and `preprocessor` to create a nlp text generation pipeline for prediction
 
         Args:
-            model (SequenceClassificationModel): a model instance
-            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
+            model (PalmForTextGeneration): a model instance
+            preprocessor (TextGenerationPreprocessor): a preprocessor instance
         """
         model = model if isinstance(
             model, PalmForTextGeneration) else Model.from_pretrained(model)


modelscope/pipelines/nlp/word_segmentation_pipeline.py (+8 −7)

@@ -5,7 +5,7 @@ import torch
 from ...metainfo import Pipelines
 from ...models import Model
 from ...models.nlp import SbertForTokenClassification
-from ...preprocessors import TokenClassifcationPreprocessor
+from ...preprocessors import TokenClassificationPreprocessor
 from ...utils.constant import Tasks
 from ..base import Pipeline, Tensor
 from ..builder import PIPELINES
@@ -18,21 +18,22 @@ __all__ = ['WordSegmentationPipeline']
     Tasks.word_segmentation, module_name=Pipelines.word_segmentation)
 class WordSegmentationPipeline(Pipeline):
 
-    def __init__(self,
-                 model: Union[SbertForTokenClassification, str],
-                 preprocessor: Optional[TokenClassifcationPreprocessor] = None,
-                 **kwargs):
+    def __init__(
+            self,
+            model: Union[SbertForTokenClassification, str],
+            preprocessor: Optional[TokenClassificationPreprocessor] = None,
+            **kwargs):
         """use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction
 
         Args:
             model (StructBertForTokenClassification): a model instance
-            preprocessor (TokenClassifcationPreprocessor): a preprocessor instance
+            preprocessor (TokenClassificationPreprocessor): a preprocessor instance
         """
         model = model if isinstance(
             model,
             SbertForTokenClassification) else Model.from_pretrained(model)
         if preprocessor is None:
-            preprocessor = TokenClassifcationPreprocessor(model.model_dir)
+            preprocessor = TokenClassificationPreprocessor(model.model_dir)
         model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
         self.tokenizer = preprocessor.tokenizer


modelscope/preprocessors/nlp.py (+88 −245)

@@ -5,7 +5,8 @@ from typing import Any, Dict, Union

from transformers import AutoTokenizer

from ..metainfo import Models, Preprocessors
from ..metainfo import Preprocessors
from ..models import Model
from ..utils.constant import Fields, InputFields
from ..utils.type_assert import type_assert
from .base import Preprocessor
@@ -13,9 +14,10 @@ from .builder import PREPROCESSORS

__all__ = [
'Tokenize', 'SequenceClassificationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
'NLIPreprocessor', 'SentimentClassificationPreprocessor',
'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor'
'FillMaskPreprocessor', 'SentenceSimilarityPreprocessor',
'ZeroShotClassificationPreprocessor'
]


@@ -33,9 +35,7 @@ class Tokenize(Preprocessor):
return data


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.nli_tokenizer)
class NLIPreprocessor(Preprocessor):
class NLPPreprocessorBase(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
@@ -45,18 +45,19 @@ class NLIPreprocessor(Preprocessor):
"""

super().__init__(*args, **kwargs)

from sofa import SbertTokenizer
self.model_dir: str = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
self.sequence_length = kwargs.pop('sequence_length', 128)
self.tokenize_kwargs = kwargs
self.tokenizer = self.build_tokenizer(model_dir)

self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
def build_tokenizer(self, model_dir):
from sofa import SbertTokenizer
return SbertTokenizer.from_pretrained(model_dir)

@type_assert(object, tuple)
def __call__(self, data: tuple) -> Dict[str, Any]:
@type_assert(object, object)
def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
"""process the raw input data

Args:
@@ -70,101 +71,54 @@ class NLIPreprocessor(Preprocessor):
Returns:
Dict[str, Any]: the preprocessed data
"""
sentence1, sentence2 = data
new_data = {
self.first_sequence: sentence1,
self.second_sequence: sentence2
}
# preprocess the data for the model input

rst = {
'id': [],
'input_ids': [],
'attention_mask': [],
'token_type_ids': []
}
text_a, text_b = None, None
if isinstance(data, str):
text_a = data
elif isinstance(data, tuple):
assert len(data) == 2
text_a, text_b = data
elif isinstance(data, dict):
text_a = data.get(self.first_sequence)
text_b = data.get(self.second_sequence, None)

max_seq_length = self.sequence_length
return self.tokenizer(text_a, text_b, **self.tokenize_kwargs)

text_a = new_data[self.first_sequence]
text_b = new_data[self.second_sequence]
feature = self.tokenizer(
text_a,
text_b,
padding=False,
truncation=True,
max_length=max_seq_length)

rst['id'].append(new_data.get('id', str(uuid.uuid4())))
rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.nli_tokenizer)
class NLIPreprocessor(NLPPreprocessorBase):

return rst
def __init__(self, model_dir: str, *args, **kwargs):
kwargs['truncation'] = True
kwargs['padding'] = False
kwargs['return_tensors'] = 'pt'
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, *args, **kwargs)


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer)
class SentimentClassificationPreprocessor(Preprocessor):
class SentimentClassificationPreprocessor(NLPPreprocessorBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

from sofa import SbertTokenizer
self.model_dir: str = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
self.sequence_length = kwargs.pop('sequence_length', 128)

self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data
kwargs['truncation'] = True
kwargs['padding'] = 'max_length'
kwargs['return_tensors'] = 'pt'
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, *args, **kwargs)

Args:
data (str): a sentence
Example:
'you are so handsome.'
Returns:
Dict[str, Any]: the preprocessed data
"""

new_data = {self.first_sequence: data}
# preprocess the data for the model input

rst = {
'id': [],
'input_ids': [],
'attention_mask': [],
'token_type_ids': []
}

max_seq_length = self.sequence_length

text_a = new_data[self.first_sequence]

text_b = new_data.get(self.second_sequence, None)
feature = self.tokenizer(
text_a,
text_b,
padding='max_length',
truncation=True,
max_length=max_seq_length)

rst['id'].append(new_data.get('id', str(uuid.uuid4())))
rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer)
class SentenceSimilarityPreprocessor(NLPPreprocessorBase):

return rst
def __init__(self, model_dir: str, *args, **kwargs):
kwargs['truncation'] = True
kwargs['padding'] = False
kwargs['return_tensors'] = 'pt'
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, *args, **kwargs)


@PREPROCESSORS.register_module(
@@ -192,36 +146,7 @@ class SequenceClassificationPreprocessor(Preprocessor):

@type_assert(object, (str, tuple, Dict))
def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str or tuple, Dict):
sentence1 (str): a sentence
Example:
'you are so handsome.'
or
(sentence1, sentence2)
sentence1 (str): a sentence
Example:
'you are so handsome.'
sentence2 (str): a sentence
Example:
'you are so beautiful.'
or
{field1: field_value1, field2: field_value2}
field1 (str): field name, default 'first_sequence'
field_value1 (str): a sentence
Example:
'you are so handsome.'

field2 (str): field name, default 'second_sequence'
field_value2 (str): a sentence
Example:
'you are so beautiful.'

Returns:
Dict[str, Any]: the preprocessed data
"""
feature = super().__call__(data)
if isinstance(data, str):
new_data = {self.first_sequence: data}
elif isinstance(data, tuple):
@@ -263,136 +188,55 @@ class SequenceClassificationPreprocessor(Preprocessor):

@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
class TextGenerationPreprocessor(Preprocessor):
class TextGenerationPreprocessor(NLPPreprocessorBase):

def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
"""preprocess the data using the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)

self.model_dir: str = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.second_sequence: str = kwargs.pop('second_sequence',
'second_sequence')
self.sequence_length: int = kwargs.pop('sequence_length', 128)
self.tokenizer = tokenizer
kwargs['truncation'] = True
kwargs['padding'] = 'max_length'
kwargs['return_tensors'] = 'pt'
kwargs['return_token_type_ids'] = False
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
super().__init__(model_dir, *args, **kwargs)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""
import torch

new_data = {self.first_sequence: data}
# preprocess the data for the model input

rst = {'input_ids': [], 'attention_mask': []}

max_seq_length = self.sequence_length

text_a = new_data.get(self.first_sequence, None)
text_b = new_data.get(self.second_sequence, None)
feature = self.tokenizer(
text_a,
text_b,
padding='max_length',
truncation=True,
max_length=max_seq_length)

rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
return {k: torch.tensor(v) for k, v in rst.items()}
def build_tokenizer(self, model_dir):
return self.tokenizer


@PREPROCESSORS.register_module(Fields.nlp)
class FillMaskPreprocessor(Preprocessor):
class FillMaskPreprocessor(NLPPreprocessorBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)
self.model_dir = model_dir
self.first_sequence: str = kwargs.pop('first_sequence',
'first_sequence')
self.sequence_length = kwargs.pop('sequence_length', 128)
try:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
except KeyError:
from sofa.utils.backend import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_dir, use_fast=False)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""
import torch

new_data = {self.first_sequence: data}
# preprocess the data for the model input

rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}

max_seq_length = self.sequence_length

text_a = new_data[self.first_sequence]
feature = self.tokenizer(
text_a,
padding='max_length',
truncation=True,
max_length=max_seq_length,
return_token_type_ids=True)

rst['input_ids'].append(feature['input_ids'])
rst['attention_mask'].append(feature['attention_mask'])
rst['token_type_ids'].append(feature['token_type_ids'])

return {k: torch.tensor(v) for k, v in rst.items()}
kwargs['truncation'] = True
kwargs['padding'] = 'max_length'
kwargs['return_tensors'] = 'pt'
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
kwargs['return_token_type_ids'] = True
super().__init__(model_dir, *args, **kwargs)

def build_tokenizer(self, model_dir):
from ..utils.hub import get_model_type
model_type = get_model_type(model_dir)
if model_type in ['sbert', 'structbert', 'bert']:
from sofa import SbertTokenizer
return SbertTokenizer.from_pretrained(model_dir, use_fast=False)
elif model_type == 'veco':
from sofa import VecoTokenizer
return VecoTokenizer.from_pretrained(model_dir, use_fast=False)
else:
# TODO Only support veco & sbert
raise RuntimeError(f'Unsupported model type: {model_type}')


@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.token_cls_tokenizer)
class TokenClassifcationPreprocessor(Preprocessor):
class TokenClassificationPreprocessor(NLPPreprocessorBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

from sofa import SbertTokenizer
self.model_dir: str = model_dir
self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
super().__init__(model_dir, *args, **kwargs)

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]:
"""process the raw input data

Args:
@@ -405,7 +249,8 @@ class TokenClassifcationPreprocessor(Preprocessor):
"""

# preprocess the data for the model input

if isinstance(data, dict):
data = data[self.first_sequence]
text = data.replace(' ', '').strip()
tokens = []
for token in text:
@@ -425,7 +270,7 @@ class TokenClassifcationPreprocessor(Preprocessor):

@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer)
class ZeroShotClassificationPreprocessor(Preprocessor):
class ZeroShotClassificationPreprocessor(NLPPreprocessorBase):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path
@@ -433,16 +278,11 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
Args:
model_dir (str): model path
"""

super().__init__(*args, **kwargs)

from sofa import SbertTokenizer
self.model_dir: str = model_dir
self.sequence_length = kwargs.pop('sequence_length', 512)
self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
super().__init__(model_dir, *args, **kwargs)

@type_assert(object, str)
def __call__(self, data: str, hypothesis_template: str,
def __call__(self, data, hypothesis_template: str,
candidate_labels: list) -> Dict[str, Any]:
"""process the raw input data

@@ -454,6 +294,9 @@ class ZeroShotClassificationPreprocessor(Preprocessor):
Returns:
Dict[str, Any]: the preprocessed data
"""
if isinstance(data, dict):
data = data.get(self.first_sequence)

pairs = [[data, hypothesis_template.format(label)]
for label in candidate_labels]
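The net effect of the rewrite above: every NLP preprocessor shares one `__call__` that normalizes str / tuple / dict inputs, and each subclass is reduced to declaring its tokenizer kwargs. A self-contained sketch of the pattern (a toy tokenizer stands in for `SbertTokenizer`, so it runs without model files):

```python
from typing import Any, Dict, Union


class NLPPreprocessorBase:
    """Build the tokenizer once; leftover kwargs become tokenizer args."""

    def __init__(self, model_dir: str, **kwargs):
        self.model_dir = model_dir
        self.first_sequence = kwargs.pop('first_sequence', 'first_sequence')
        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
        self.tokenize_kwargs = kwargs
        self.tokenizer = self.build_tokenizer(model_dir)

    def build_tokenizer(self, model_dir: str):
        # real code: SbertTokenizer.from_pretrained(model_dir); subclasses
        # like FillMaskPreprocessor override this to pick by model type
        def toy_tokenizer(text_a, text_b=None, **kw):
            tokens = text_a.split() + (text_b.split() if text_b else [])
            ids = list(range(len(tokens)))[:kw.get('max_length', 128)]
            return {'input_ids': ids, 'attention_mask': [1] * len(ids)}
        return toy_tokenizer

    def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
        # normalize the three accepted input shapes to (text_a, text_b)
        text_a, text_b = None, None
        if isinstance(data, str):
            text_a = data
        elif isinstance(data, tuple):
            text_a, text_b = data
        elif isinstance(data, dict):
            text_a = data.get(self.first_sequence)
            text_b = data.get(self.second_sequence)
        return self.tokenizer(text_a, text_b, **self.tokenize_kwargs)


class NLIPreprocessor(NLPPreprocessorBase):
    """A subclass is just a tokenization policy."""

    def __init__(self, model_dir: str, **kwargs):
        kwargs['truncation'] = True
        kwargs['padding'] = False
        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
        super().__init__(model_dir, **kwargs)


print(NLIPreprocessor('.')(('a premise here', 'a hypothesis here')))
```

The design payoff is that input normalization and tokenizer construction live in one place, which is what shrinks nlp.py by 245 lines while adding only 88.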



modelscope/utils/hub.py (+18 −0)

@@ -11,6 +11,9 @@ from modelscope.hub.file_download import model_file_download
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile
+from .logger import get_logger
+
+logger = get_logger(__name__)
 
 
 def create_model_if_not_exist(
@@ -67,3 +70,18 @@ def auto_load(model: Union[str, List[str]]):
     ]
 
     return model
+
+
+def get_model_type(model_dir):
+    try:
+        configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION)
+        config_file = osp.join(model_dir, 'config.json')
+        if osp.isfile(configuration_file):
+            cfg = Config.from_file(configuration_file)
+            return cfg.model.model_type if hasattr(cfg.model, 'model_type') and not hasattr(cfg.model, 'type') \
+                else cfg.model.type
+        elif osp.isfile(config_file):
+            cfg = Config.from_file(config_file)
+            return cfg.model_type if hasattr(cfg, 'model_type') else None
+    except Exception as e:
+        logger.error(f'parse config file failed with error: {e}')
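`get_model_type` exists so that `FillMaskPreprocessor.build_tokenizer` (in the nlp.py diff above) can pick a tokenizer without being told the architecture. In outline (the directory path is illustrative):

```python
from modelscope.utils.hub import get_model_type

model_dir = '/path/to/downloaded/model'  # illustrative
# reads ModelFile.CONFIGURATION if present, else config.json; returns None
# when neither file yields a model type (parse errors are logged)
model_type = get_model_type(model_dir)

if model_type in ['sbert', 'structbert', 'bert']:
    from sofa import SbertTokenizer
    tokenizer = SbertTokenizer.from_pretrained(model_dir, use_fast=False)
elif model_type == 'veco':
    from sofa import VecoTokenizer
    tokenizer = VecoTokenizer.from_pretrained(model_dir, use_fast=False)
else:
    raise RuntimeError(f'Unsupported model type: {model_type}')
```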

tests/pipelines/test_sentence_similarity.py (+3 −3)

@@ -6,7 +6,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
 from modelscope.models.nlp import SbertForSentenceSimilarity
 from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
-from modelscope.preprocessors import SequenceClassificationPreprocessor
+from modelscope.preprocessors import SentenceSimilarityPreprocessor
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level
 
@@ -19,7 +19,7 @@ class SentenceSimilarityTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
         cache_path = snapshot_download(self.model_id)
-        tokenizer = SequenceClassificationPreprocessor(cache_path)
+        tokenizer = SentenceSimilarityPreprocessor(cache_path)
         model = SbertForSentenceSimilarity(cache_path, tokenizer=tokenizer)
         pipeline1 = SentenceSimilarityPipeline(model, preprocessor=tokenizer)
         pipeline2 = pipeline(
@@ -35,7 +35,7 @@ class SentenceSimilarityTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
-        tokenizer = SequenceClassificationPreprocessor(model.model_dir)
+        tokenizer = SentenceSimilarityPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.sentence_similarity,
             model=model,


tests/pipelines/test_word_segmentation.py (+3 −3)

@@ -6,7 +6,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
 from modelscope.models.nlp import SbertForTokenClassification
 from modelscope.pipelines import WordSegmentationPipeline, pipeline
-from modelscope.preprocessors import TokenClassifcationPreprocessor
+from modelscope.preprocessors import TokenClassificationPreprocessor
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level
 
@@ -18,7 +18,7 @@ class WordSegmentationTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_by_direct_model_download(self):
         cache_path = snapshot_download(self.model_id)
-        tokenizer = TokenClassifcationPreprocessor(cache_path)
+        tokenizer = TokenClassificationPreprocessor(cache_path)
         model = SbertForTokenClassification(cache_path, tokenizer=tokenizer)
         pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer)
         pipeline2 = pipeline(
@@ -31,7 +31,7 @@ class WordSegmentationTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
-        tokenizer = TokenClassifcationPreprocessor(model.model_dir)
+        tokenizer = TokenClassificationPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.word_segmentation, model=model, preprocessor=tokenizer)
         print(pipeline_ins(input=self.sentence))

