From cf194ef6cd7b9c0d72fa37b400c0ab0580304171 Mon Sep 17 00:00:00 2001
From: "zhangzhicheng.zzc"
Date: Tue, 5 Jul 2022 20:40:48 +0800
Subject: [PATCH] [to #42322933] nlp preprocessor refactor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9269314

Refactor the NLP preprocessors around a shared base class. Highlights of
the squashed branch history:

* introduce NLPPreprocessorBase: it resolves the raw input (a str, a
  (sentence1, sentence2) tuple, or a dict keyed by first_sequence /
  second_sequence), builds the tokenizer from the model directory, and
  forwards all remaining constructor kwargs to the tokenizer call
* derive NLIPreprocessor, SentimentClassificationPreprocessor,
  SentenceSimilarityPreprocessor, TextGenerationPreprocessor,
  FillMaskPreprocessor, TokenClassificationPreprocessor and
  ZeroShotClassificationPreprocessor from the base, so that each subclass
  only pins its default tokenizer arguments (truncation, padding,
  max_length, return_tensors)
* fix the TokenClassifcationPreprocessor ->
  TokenClassificationPreprocessor typo across docs, pipelines and tests
* add SentenceSimilarityPreprocessor (registered as 'sen-sim-tokenizer'
  in metainfo) and switch the sentence-similarity pipeline and tests to
  it from SequenceClassificationPreprocessor
* add get_model_type() to modelscope/utils/hub.py so that
  FillMaskPreprocessor can build the matching tokenizer (SbertTokenizer
  for sbert/structbert/bert, VecoTokenizer for veco)
* update the default preprocessor args and the fill-mask preprocessor;
  fix the copy-pasted docstrings in the text-generation and
  word-segmentation pipelines
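A minimal sketch of the new call contract (the model id is borrowed from
the word-segmentation tutorial in this diff and is only a stand-in; the
example sentences are made up):

```python
from modelscope.models import Model
from modelscope.preprocessors import SentenceSimilarityPreprocessor

# Stand-in model id, taken from docs/source/tutorials/pipeline.md below;
# any structbert-style model directory with a vocab works the same way.
model = Model.from_pretrained(
    'damo/nlp_structbert_word-segmentation_chinese-base')

# The subclass only pins tokenizer defaults (truncation=True,
# padding=False, return_tensors='pt', max_length=128); every other
# kwarg is forwarded to the tokenizer call by NLPPreprocessorBase.
preprocessor = SentenceSimilarityPreprocessor(model.model_dir)

# NLPPreprocessorBase.__call__ accepts a plain str, a pair tuple, or a
# dict keyed by first_sequence/second_sequence -- equivalent inputs here.
features = preprocessor(('这件衣服很好看', '这件衣服真漂亮'))
features = preprocessor({
    'first_sequence': '这件衣服很好看',
    'second_sequence': '这件衣服真漂亮'
})
print(features['input_ids'].shape)
```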
The branch history squashed into this change also covers: the SPACE
dialog work (intent, dialog modeling/generation and DST preprocessing),
the fill-mask, NLI, sentiment-classification and
zero-shot-classification tasks, and the CSANMT translation pipeline
(including deriving the CSANMT model from Model and accepting **kwargs
to prevent call errors); the MaskedLMModelBase ->
MaskedLanguageModelBase rename; calling eval() on models inside
pipelines; translating Chinese comments to English; test-level
annotations for CI; and repeated master merges plus pre-commit, isort
and flake8 fixups along the way.
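For reviewers, the tokenizer selection introduced for fill-mask reduces
to the following sketch (the local model directory is hypothetical; the
branch logic mirrors FillMaskPreprocessor.build_tokenizer and
get_model_type from the hunks below):

```python
import os.path as osp

from modelscope.utils.hub import get_model_type

# Hypothetical local snapshot directory, used only for illustration.
model_dir = osp.expanduser('~/models/nlp_structbert_fill-mask')

# get_model_type() reads model.type / model.model_type from
# configuration.json, falls back to model_type in config.json, and
# logs an error and returns None when neither can be parsed.
model_type = get_model_type(model_dir)

if model_type in ['sbert', 'structbert', 'bert']:
    from sofa import SbertTokenizer
    tokenizer = SbertTokenizer.from_pretrained(model_dir, use_fast=False)
elif model_type == 'veco':
    from sofa import VecoTokenizer
    tokenizer = VecoTokenizer.from_pretrained(model_dir, use_fast=False)
else:
    # Only sbert-family and veco checkpoints are supported so far.
    raise RuntimeError(f'Unsupported model type: {model_type}')
```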
---
 docs/source/tutorials/pipeline.md             |   4 +-
 modelscope/metainfo.py                        |   1 +
 .../nlp/sentence_similarity_pipeline.py       |   8 +-
 .../pipelines/nlp/text_generation_pipeline.py |   6 +-
 .../nlp/word_segmentation_pipeline.py         |  15 +-
 modelscope/preprocessors/nlp.py               | 333 +++++-------------
 modelscope/utils/hub.py                       |  18 +
 tests/pipelines/test_sentence_similarity.py   |   6 +-
 tests/pipelines/test_word_segmentation.py     |   6 +-
 9 files changed, 130 insertions(+), 267 deletions(-)

diff --git a/docs/source/tutorials/pipeline.md b/docs/source/tutorials/pipeline.md
index 2d1f18e2..ebdc06f3 100644
--- a/docs/source/tutorials/pipeline.md
+++ b/docs/source/tutorials/pipeline.md
@@ -37,9 +37,9 @@ pipeline函数支持传入实例化的预处理对象、模型对象,从而支
 1. 首先,创建预处理方法和模型
 ```python
 from modelscope.models import Model
-from modelscope.preprocessors import TokenClassifcationPreprocessor
+from modelscope.preprocessors import TokenClassificationPreprocessor
 model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base')
-tokenizer = TokenClassifcationPreprocessor(model.model_dir)
+tokenizer = TokenClassificationPreprocessor(model.model_dir)
 ```
 2. 使用tokenizer和模型对象创建pipeline
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 1f8440de..555de643 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -106,6 +106,7 @@ class Preprocessors(object):
     load_image = 'load-image'

     # nlp preprocessor
+    sen_sim_tokenizer = 'sen-sim-tokenizer'
     bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer'
     palm_text_gen_tokenizer = 'palm-text-gen-tokenizer'
     token_cls_tokenizer = 'token-cls-tokenizer'
diff --git a/modelscope/pipelines/nlp/sentence_similarity_pipeline.py b/modelscope/pipelines/nlp/sentence_similarity_pipeline.py
index 4cccd996..c8484521 100644
--- a/modelscope/pipelines/nlp/sentence_similarity_pipeline.py
+++ b/modelscope/pipelines/nlp/sentence_similarity_pipeline.py
@@ -6,7 +6,7 @@ import torch
 from ...metainfo import Pipelines
 from ...models import Model
 from ...models.nlp import SbertForSentenceSimilarity
-from ...preprocessors import SequenceClassificationPreprocessor
+from ...preprocessors import SentenceSimilarityPreprocessor
 from ...utils.constant import Tasks
 from ..base import Input, Pipeline
 from ..builder import PIPELINES
@@ -21,7 +21,7 @@ class SentenceSimilarityPipeline(Pipeline):

     def __init__(self,
                  model: Union[Model, str],
-                 preprocessor: SequenceClassificationPreprocessor = None,
+                 preprocessor: SentenceSimilarityPreprocessor = None,
                  first_sequence='first_sequence',
                  second_sequence='second_sequence',
                  **kwargs):
@@ -29,7 +29,7 @@ class SentenceSimilarityPipeline(Pipeline):

         Args:
             model (SbertForSentenceSimilarity): a model instance
-            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
+            preprocessor (SentenceSimilarityPreprocessor): a preprocessor instance
         """
         assert isinstance(model, str) or isinstance(model, SbertForSentenceSimilarity), \
             'model must be a single str or SbertForSentenceSimilarity'
@@ -37,7 +37,7 @@ class SentenceSimilarityPipeline(Pipeline):
             model, SbertForSentenceSimilarity) else Model.from_pretrained(model)
         if preprocessor is None:
-            preprocessor = SequenceClassificationPreprocessor(
+            preprocessor = SentenceSimilarityPreprocessor(
                 sc_model.model_dir,
                 first_sequence=first_sequence,
                 second_sequence=second_sequence)
diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py
index d5e9e58b..d383838e 100644
--- a/modelscope/pipelines/nlp/text_generation_pipeline.py
+++ b/modelscope/pipelines/nlp/text_generation_pipeline.py
@@ -22,11 +22,11 @@ class TextGenerationPipeline(Pipeline):
                  model: Union[PalmForTextGeneration, str],
                  preprocessor: Optional[TextGenerationPreprocessor] = None,
                  **kwargs):
-        """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction
+        """use `model` and `preprocessor` to create a nlp text generation pipeline for prediction

         Args:
-            model (SequenceClassificationModel): a model instance
-            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
+            model (PalmForTextGeneration): a model instance
+            preprocessor (TextGenerationPreprocessor): a preprocessor instance
         """
         model = model if isinstance(
             model, PalmForTextGeneration) else Model.from_pretrained(model)
diff --git a/modelscope/pipelines/nlp/word_segmentation_pipeline.py b/modelscope/pipelines/nlp/word_segmentation_pipeline.py
index 66b333cb..c220adbb 100644
--- a/modelscope/pipelines/nlp/word_segmentation_pipeline.py
+++ b/modelscope/pipelines/nlp/word_segmentation_pipeline.py
@@ -5,7 +5,7 @@ import torch
 from ...metainfo import Pipelines
 from ...models import Model
 from ...models.nlp import SbertForTokenClassification
-from ...preprocessors import TokenClassifcationPreprocessor
+from ...preprocessors import TokenClassificationPreprocessor
 from ...utils.constant import Tasks
 from ..base import Pipeline, Tensor
 from ..builder import PIPELINES
@@ -18,21 +18,22 @@ __all__ = ['WordSegmentationPipeline']
     Tasks.word_segmentation, module_name=Pipelines.word_segmentation)
 class WordSegmentationPipeline(Pipeline):

-    def __init__(self,
-                 model: Union[SbertForTokenClassification, str],
-                 preprocessor: Optional[TokenClassifcationPreprocessor] = None,
-                 **kwargs):
+    def __init__(
+            self,
+            model: Union[SbertForTokenClassification, str],
+            preprocessor: Optional[TokenClassificationPreprocessor] = None,
+            **kwargs):
         """use `model` and `preprocessor` to create a nlp word segmentation pipeline for prediction

         Args:
             model (StructBertForTokenClassification): a model instance
-            preprocessor (TokenClassifcationPreprocessor): a preprocessor instance
+            preprocessor (TokenClassificationPreprocessor): a preprocessor instance
         """
         model = model if isinstance(
             model,
             SbertForTokenClassification) else Model.from_pretrained(model)
         if preprocessor is None:
-            preprocessor = TokenClassifcationPreprocessor(model.model_dir)
+            preprocessor = TokenClassificationPreprocessor(model.model_dir)
         model.eval()
         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
         self.tokenizer = preprocessor.tokenizer
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 007a3ac1..360d97aa 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -5,7 +5,8 @@ from typing import Any, Dict, Union

 from transformers import AutoTokenizer

-from ..metainfo import Models, Preprocessors
+from ..metainfo import Preprocessors
+from ..models import Model
 from ..utils.constant import Fields, InputFields
 from ..utils.type_assert import type_assert
 from .base import Preprocessor
@@ -13,9 +14,10 @@ from .builder import PREPROCESSORS

 __all__ = [
     'Tokenize', 'SequenceClassificationPreprocessor',
-    'TextGenerationPreprocessor', 'TokenClassifcationPreprocessor',
+    'TextGenerationPreprocessor', 'TokenClassificationPreprocessor',
     'NLIPreprocessor', 'SentimentClassificationPreprocessor',
-    'FillMaskPreprocessor', 'ZeroShotClassificationPreprocessor'
+    'FillMaskPreprocessor', 'SentenceSimilarityPreprocessor',
+    'ZeroShotClassificationPreprocessor'
 ]
@@ -33,9 +35,7 @@ class Tokenize(Preprocessor):
         return data


-@PREPROCESSORS.register_module(
-    Fields.nlp, module_name=Preprocessors.nli_tokenizer)
-class NLIPreprocessor(Preprocessor):
+class NLPPreprocessorBase(Preprocessor):

     def __init__(self, model_dir: str, *args, **kwargs):
         """preprocess the data via the vocab.txt from the `model_dir` path
@@ -45,18 +45,19 @@
         """
         super().__init__(*args, **kwargs)
-
-        from sofa import SbertTokenizer
         self.model_dir: str = model_dir
         self.first_sequence: str = kwargs.pop('first_sequence',
                                               'first_sequence')
         self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
-        self.sequence_length = kwargs.pop('sequence_length', 128)
+        self.tokenize_kwargs = kwargs
+        self.tokenizer = self.build_tokenizer(model_dir)

-        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+    def build_tokenizer(self, model_dir):
+        from sofa import SbertTokenizer
+        return SbertTokenizer.from_pretrained(model_dir)

-    @type_assert(object, tuple)
-    def __call__(self, data: tuple) -> Dict[str, Any]:
+    @type_assert(object, object)
+    def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
         """process the raw input data

         Args:
@@ -70,101 +71,54 @@
         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        sentence1, sentence2 = data
-        new_data = {
-            self.first_sequence: sentence1,
-            self.second_sequence: sentence2
-        }
-        # preprocess the data for the model input
-        rst = {
-            'id': [],
-            'input_ids': [],
-            'attention_mask': [],
-            'token_type_ids': []
-        }
+        text_a, text_b = None, None
+        if isinstance(data, str):
+            text_a = data
+        elif isinstance(data, tuple):
+            assert len(data) == 2
+            text_a, text_b = data
+        elif isinstance(data, dict):
+            text_a = data.get(self.first_sequence)
+            text_b = data.get(self.second_sequence, None)

-        max_seq_length = self.sequence_length
+        return self.tokenizer(text_a, text_b, **self.tokenize_kwargs)

-        text_a = new_data[self.first_sequence]
-        text_b = new_data[self.second_sequence]
-        feature = self.tokenizer(
-            text_a,
-            text_b,
-            padding=False,
-            truncation=True,
-            max_length=max_seq_length)
-        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
-        rst['input_ids'].append(feature['input_ids'])
-        rst['attention_mask'].append(feature['attention_mask'])
-        rst['token_type_ids'].append(feature['token_type_ids'])

+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.nli_tokenizer)
+class NLIPreprocessor(NLPPreprocessorBase):

-        return rst
+    def __init__(self, model_dir: str, *args, **kwargs):
+        kwargs['truncation'] = True
+        kwargs['padding'] = False
+        kwargs['return_tensors'] = 'pt'
+        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
+        super().__init__(model_dir, *args, **kwargs)


 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=Preprocessors.sen_cls_tokenizer)
-class SentimentClassificationPreprocessor(Preprocessor):
+class SentimentClassificationPreprocessor(NLPPreprocessorBase):

     def __init__(self, model_dir: str, *args, **kwargs):
-        """preprocess the data via the vocab.txt from the `model_dir` path
-
-        Args:
-            model_dir (str): model path
-        """
-
-        super().__init__(*args, **kwargs)
-
-        from sofa import SbertTokenizer
-        self.model_dir: str = model_dir
-        self.first_sequence: str = kwargs.pop('first_sequence',
-                                              'first_sequence')
-        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
-        self.sequence_length = kwargs.pop('sequence_length', 128)
-
-        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
-
-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
-        """process the raw input data
+        kwargs['truncation'] = True
+        kwargs['padding'] = 'max_length'
+        kwargs['return_tensors'] = 'pt'
+        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
+        super().__init__(model_dir, *args, **kwargs)

-        Args:
-            data (str): a sentence
-                Example:
-                    'you are so handsome.'
-
-        Returns:
-            Dict[str, Any]: the preprocessed data
-        """
-
-        new_data = {self.first_sequence: data}
-        # preprocess the data for the model input
-        rst = {
-            'id': [],
-            'input_ids': [],
-            'attention_mask': [],
-            'token_type_ids': []
-        }
-
-        max_seq_length = self.sequence_length
-
-        text_a = new_data[self.first_sequence]
-
-        text_b = new_data.get(self.second_sequence, None)
-        feature = self.tokenizer(
-            text_a,
-            text_b,
-            padding='max_length',
-            truncation=True,
-            max_length=max_seq_length)
-
-        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
-        rst['input_ids'].append(feature['input_ids'])
-        rst['attention_mask'].append(feature['attention_mask'])
-        rst['token_type_ids'].append(feature['token_type_ids'])
+@PREPROCESSORS.register_module(
+    Fields.nlp, module_name=Preprocessors.sen_sim_tokenizer)
+class SentenceSimilarityPreprocessor(NLPPreprocessorBase):

-        return rst
+    def __init__(self, model_dir: str, *args, **kwargs):
+        kwargs['truncation'] = True
+        kwargs['padding'] = False
+        kwargs['return_tensors'] = 'pt'
+        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
+        super().__init__(model_dir, *args, **kwargs)


 @PREPROCESSORS.register_module(
@@ -192,36 +146,7 @@ class SequenceClassificationPreprocessor(Preprocessor):

     @type_assert(object, (str, tuple, Dict))
     def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
-        """process the raw input data
-
-        Args:
-            data (str or tuple, Dict):
-                sentence1 (str): a sentence
-                    Example:
-                        'you are so handsome.'
-                or
-                (sentence1, sentence2)
-                    sentence1 (str): a sentence
-                        Example:
-                            'you are so handsome.'
-                    sentence2 (str): a sentence
-                        Example:
-                            'you are so beautiful.'
-                or
-                {field1: field_value1, field2: field_value2}
-                    field1 (str): field name, default 'first_sequence'
-                    field_value1 (str): a sentence
-                        Example:
-                            'you are so handsome.'
-
-                    field2 (str): field name, default 'second_sequence'
-                    field_value2 (str): a sentence
-                        Example:
-                            'you are so beautiful.'
-
-        Returns:
-            Dict[str, Any]: the preprocessed data
-        """
+        feature = super().__call__(data)
         if isinstance(data, str):
             new_data = {self.first_sequence: data}
         elif isinstance(data, tuple):
@@ -263,136 +188,55 @@

 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer)
-class TextGenerationPreprocessor(Preprocessor):
+class TextGenerationPreprocessor(NLPPreprocessorBase):

     def __init__(self, model_dir: str, tokenizer, *args, **kwargs):
-        """preprocess the data using the vocab.txt from the `model_dir` path
-
-        Args:
-            model_dir (str): model path
-        """
-        super().__init__(*args, **kwargs)
-
-        self.model_dir: str = model_dir
-        self.first_sequence: str = kwargs.pop('first_sequence',
-                                              'first_sequence')
-        self.second_sequence: str = kwargs.pop('second_sequence',
-                                               'second_sequence')
-        self.sequence_length: int = kwargs.pop('sequence_length', 128)
         self.tokenizer = tokenizer
+        kwargs['truncation'] = True
+        kwargs['padding'] = 'max_length'
+        kwargs['return_tensors'] = 'pt'
+        kwargs['return_token_type_ids'] = False
+        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
+        super().__init__(model_dir, *args, **kwargs)

-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
-        """process the raw input data
-
-        Args:
-            data (str): a sentence
-                Example:
-                    'you are so handsome.'
-
-        Returns:
-            Dict[str, Any]: the preprocessed data
-        """
-        import torch
-
-        new_data = {self.first_sequence: data}
-        # preprocess the data for the model input
-
-        rst = {'input_ids': [], 'attention_mask': []}
-
-        max_seq_length = self.sequence_length
-
-        text_a = new_data.get(self.first_sequence, None)
-        text_b = new_data.get(self.second_sequence, None)
-        feature = self.tokenizer(
-            text_a,
-            text_b,
-            padding='max_length',
-            truncation=True,
-            max_length=max_seq_length)
-
-        rst['input_ids'].append(feature['input_ids'])
-        rst['attention_mask'].append(feature['attention_mask'])
-        return {k: torch.tensor(v) for k, v in rst.items()}
+    def build_tokenizer(self, model_dir):
+        return self.tokenizer


 @PREPROCESSORS.register_module(Fields.nlp)
-class FillMaskPreprocessor(Preprocessor):
+class FillMaskPreprocessor(NLPPreprocessorBase):

     def __init__(self, model_dir: str, *args, **kwargs):
-        """preprocess the data via the vocab.txt from the `model_dir` path
-
-        Args:
-            model_dir (str): model path
-        """
-        super().__init__(*args, **kwargs)
-        self.model_dir = model_dir
-        self.first_sequence: str = kwargs.pop('first_sequence',
-                                              'first_sequence')
-        self.sequence_length = kwargs.pop('sequence_length', 128)
-        try:
-            from transformers import AutoTokenizer
-            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
-        except KeyError:
-            from sofa.utils.backend import AutoTokenizer
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_dir, use_fast=False)
-
-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
-        """process the raw input data
-
-        Args:
-            data (str): a sentence
-                Example:
-                    'you are so handsome.'
-
-        Returns:
-            Dict[str, Any]: the preprocessed data
-        """
-        import torch
-
-        new_data = {self.first_sequence: data}
-        # preprocess the data for the model input
-
-        rst = {'input_ids': [], 'attention_mask': [], 'token_type_ids': []}
-
-        max_seq_length = self.sequence_length
-
-        text_a = new_data[self.first_sequence]
-        feature = self.tokenizer(
-            text_a,
-            padding='max_length',
-            truncation=True,
-            max_length=max_seq_length,
-            return_token_type_ids=True)
-
-        rst['input_ids'].append(feature['input_ids'])
-        rst['attention_mask'].append(feature['attention_mask'])
-        rst['token_type_ids'].append(feature['token_type_ids'])
-
-        return {k: torch.tensor(v) for k, v in rst.items()}
+        kwargs['truncation'] = True
+        kwargs['padding'] = 'max_length'
+        kwargs['return_tensors'] = 'pt'
+        kwargs['max_length'] = kwargs.pop('sequence_length', 128)
+        kwargs['return_token_type_ids'] = True
+        super().__init__(model_dir, *args, **kwargs)
+
+    def build_tokenizer(self, model_dir):
+        from ..utils.hub import get_model_type
+        model_type = get_model_type(model_dir)
+        if model_type in ['sbert', 'structbert', 'bert']:
+            from sofa import SbertTokenizer
+            return SbertTokenizer.from_pretrained(model_dir, use_fast=False)
+        elif model_type == 'veco':
+            from sofa import VecoTokenizer
+            return VecoTokenizer.from_pretrained(model_dir, use_fast=False)
+        else:
+            # TODO Only support veco & sbert
+            raise RuntimeError(f'Unsupported model type: {model_type}')


 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=Preprocessors.token_cls_tokenizer)
-class TokenClassifcationPreprocessor(Preprocessor):
+class TokenClassificationPreprocessor(NLPPreprocessorBase):

     def __init__(self, model_dir: str, *args, **kwargs):
-        """preprocess the data via the vocab.txt from the `model_dir` path
-
-        Args:
-            model_dir (str): model path
-        """
-
-        super().__init__(*args, **kwargs)
-
-        from sofa import SbertTokenizer
-        self.model_dir: str = model_dir
-        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+        super().__init__(model_dir, *args, **kwargs)

     @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
+    def __call__(self, data: Union[str, Dict]) -> Dict[str, Any]:
         """process the raw input data

         Args:
@@ -405,7 +249,8 @@
         """

         # preprocess the data for the model input
-
+        if isinstance(data, dict):
+            data = data[self.first_sequence]
         text = data.replace(' ', '').strip()
         tokens = []
         for token in text:
@@ -425,7 +270,7 @@

 @PREPROCESSORS.register_module(
     Fields.nlp, module_name=Preprocessors.zero_shot_cls_tokenizer)
-class ZeroShotClassificationPreprocessor(Preprocessor):
+class ZeroShotClassificationPreprocessor(NLPPreprocessorBase):

     def __init__(self, model_dir: str, *args, **kwargs):
         """preprocess the data via the vocab.txt from the `model_dir` path
@@ -433,16 +278,11 @@
         Args:
             model_dir (str): model path
         """
-
-        super().__init__(*args, **kwargs)
-
-        from sofa import SbertTokenizer
-        self.model_dir: str = model_dir
         self.sequence_length = kwargs.pop('sequence_length', 512)
-        self.tokenizer = SbertTokenizer.from_pretrained(self.model_dir)
+        super().__init__(model_dir, *args, **kwargs)

     @type_assert(object, str)
-    def __call__(self, data: str, hypothesis_template: str,
+    def __call__(self, data, hypothesis_template: str,
                  candidate_labels: list) -> Dict[str, Any]:
         """process the raw input data

@@ -454,6 +294,9 @@
         Returns:
             Dict[str, Any]: the preprocessed data
         """
+        if isinstance(data, dict):
+            data = data.get(self.first_sequence)
+
         pairs = [[data, hypothesis_template.format(label)]
                  for label in candidate_labels]

diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py
index 3b7e80ef..f2a3c120 100644
--- a/modelscope/utils/hub.py
+++ b/modelscope/utils/hub.py
@@ -11,6 +11,9 @@ from modelscope.hub.file_download import model_file_download
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile
+from .logger import get_logger
+
+logger = get_logger(__name__)


 def create_model_if_not_exist(
@@ -67,3 +70,18 @@ def auto_load(model: Union[str, List[str]]):
         ]
     return model
+
+
+def get_model_type(model_dir):
+    try:
+        configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION)
+        config_file = osp.join(model_dir, 'config.json')
+        if osp.isfile(configuration_file):
+            cfg = Config.from_file(configuration_file)
+            return cfg.model.model_type if hasattr(cfg.model, 'model_type') and not hasattr(cfg.model, 'type') \
+                else cfg.model.type
+        elif osp.isfile(config_file):
+            cfg = Config.from_file(config_file)
+            return cfg.model_type if hasattr(cfg, 'model_type') else None
+    except Exception as e:
+        logger.error(f'parse config file failed with error: {e}')
diff --git a/tests/pipelines/test_sentence_similarity.py b/tests/pipelines/test_sentence_similarity.py
index df38593f..02edb87f 100644
--- a/tests/pipelines/test_sentence_similarity.py
+++ b/tests/pipelines/test_sentence_similarity.py
@@ -6,7 +6,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
 from modelscope.models.nlp import SbertForSentenceSimilarity
 from modelscope.pipelines import SentenceSimilarityPipeline, pipeline
-from modelscope.preprocessors import SequenceClassificationPreprocessor
+from modelscope.preprocessors import SentenceSimilarityPreprocessor
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level

@@ -19,7 +19,7 @@ class SentenceSimilarityTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run(self):
         cache_path = snapshot_download(self.model_id)
-        tokenizer = SequenceClassificationPreprocessor(cache_path)
+        tokenizer = SentenceSimilarityPreprocessor(cache_path)
         model = SbertForSentenceSimilarity(cache_path, tokenizer=tokenizer)
         pipeline1 = SentenceSimilarityPipeline(model, preprocessor=tokenizer)
         pipeline2 = pipeline(
@@ -35,7 +35,7 @@ class SentenceSimilarityTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
-        tokenizer = SequenceClassificationPreprocessor(model.model_dir)
+        tokenizer = SentenceSimilarityPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.sentence_similarity,
             model=model,
diff --git a/tests/pipelines/test_word_segmentation.py b/tests/pipelines/test_word_segmentation.py
index d33e4bdb..51f14011 100644
--- a/tests/pipelines/test_word_segmentation.py
+++ b/tests/pipelines/test_word_segmentation.py
@@ -6,7 +6,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models import Model
 from modelscope.models.nlp import SbertForTokenClassification
 from modelscope.pipelines import WordSegmentationPipeline, pipeline
-from modelscope.preprocessors import TokenClassifcationPreprocessor
+from modelscope.preprocessors import TokenClassificationPreprocessor
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level

@@ -18,7 +18,7 @@ class WordSegmentationTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_by_direct_model_download(self):
         cache_path = snapshot_download(self.model_id)
-        tokenizer = TokenClassifcationPreprocessor(cache_path)
+        tokenizer = TokenClassificationPreprocessor(cache_path)
         model = SbertForTokenClassification(cache_path, tokenizer=tokenizer)
         pipeline1 = WordSegmentationPipeline(model, preprocessor=tokenizer)
         pipeline2 = pipeline(
@@ -31,7 +31,7 @@ class WordSegmentationTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_from_modelhub(self):
         model = Model.from_pretrained(self.model_id)
-        tokenizer = TokenClassifcationPreprocessor(model.model_dir)
+        tokenizer = TokenClassificationPreprocessor(model.model_dir)
         pipeline_ins = pipeline(
             task=Tasks.word_segmentation, model=model, preprocessor=tokenizer)
         print(pipeline_ins(input=self.sentence))