From e288cf076e791ccfd23eb165b21a6fdbeb958abb Mon Sep 17 00:00:00 2001 From: "wenmeng.zwm" Date: Wed, 22 Jun 2022 14:15:32 +0800 Subject: [PATCH] [to #42362853] refactor pipeline and standardize module_name * using get_model to validate hub path * support reading pipeline info from configuration file * add metainfo const * update model type and pipeline type and fix UT * relax requimrent for protobuf * skip two dataset tests due to temporal failure Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9118154 --- modelscope/metainfo.py | 94 +++++++++++++++++++ .../models/audio/tts/am/sambert_hifi_16k.py | 4 +- .../generic_text_to_speech_frontend.py | 3 +- .../models/audio/tts/vocoder/hifigan16k.py | 3 +- modelscope/models/base.py | 19 +++- .../multi_model/image_captioning_model.py | 4 +- .../nlp/bert_for_sequence_classification.py | 4 +- .../models/nlp/palm_for_text_generation.py | 3 +- .../nlp/sbert_for_sentence_similarity.py | 4 +- .../nlp/sbert_for_token_classification.py | 5 +- .../pipelines/audio/linear_aec_pipeline.py | 4 +- .../audio/text_to_speech_pipeline.py | 3 +- modelscope/pipelines/base.py | 12 +-- modelscope/pipelines/builder.py | 65 ++++++++----- .../pipelines/cv/image_cartoon_pipeline.py | 3 +- .../pipelines/cv/image_matting_pipeline.py | 3 +- .../pipelines/cv/ocr_detection_pipeline.py | 3 +- .../multi_modal/image_captioning_pipeline.py | 4 +- .../nlp/sentence_similarity_pipeline.py | 4 +- .../nlp/sequence_classification_pipeline.py | 3 +- .../pipelines/nlp/text_generation_pipeline.py | 4 +- .../nlp/word_segmentation_pipeline.py | 4 +- modelscope/pipelines/util.py | 53 +++++++++-- modelscope/preprocessors/image.py | 3 +- modelscope/preprocessors/multi_model.py | 3 +- modelscope/preprocessors/nlp.py | 8 +- modelscope/preprocessors/text_to_speech.py | 5 +- modelscope/utils/hub.py | 40 +++++++- requirements/audio.txt | 10 +- tests/pipelines/test_image_matting.py | 2 +- tests/pipelines/test_speech_signal_process.py | 3 +- tests/pipelines/test_text_classification.py | 25 ----- tests/pipelines/test_text_to_speech.py | 5 +- tests/preprocessors/test_text_to_speech.py | 3 +- tests/pydatasets/test_py_dataset.py | 2 + 35 files changed, 303 insertions(+), 114 deletions(-) create mode 100644 modelscope/metainfo.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py new file mode 100644 index 00000000..63af2ec4 --- /dev/null +++ b/modelscope/metainfo.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + + +class Models(object): + """ Names for different models. + + Holds the standard model name to use for identifying different model. + This should be used to register models. + + Model name should only contain model info but not task info. + """ + # vision models + + # nlp models + bert = 'bert' + palm2_0 = 'palm2.0' + structbert = 'structbert' + + # audio models + sambert_hifi_16k = 'sambert-hifi-16k' + generic_tts_frontend = 'generic-tts-frontend' + hifigan16k = 'hifigan16k' + + # multi-modal models + ofa = 'ofa' + + +class Pipelines(object): + """ Names for different pipelines. + + Holds the standard pipline name to use for identifying different pipeline. + This should be used to register pipelines. + + For pipeline which support different models and implements the common function, we + should use task name for this pipeline. + For pipeline which suuport only one model, we should use ${Model}-${Task} as its name. + """ + # vision tasks + image_matting = 'unet-image-matting' + person_image_cartoon = 'unet-person-image-cartoon' + ocr_detection = 'resnet18-ocr-detection' + + # nlp tasks + sentence_similarity = 'sentence-similarity' + word_segmentation = 'word-segmentation' + text_generation = 'text-generation' + sentiment_analysis = 'sentiment-analysis' + + # audio tasks + sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' + speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' + + # multi-modal tasks + image_caption = 'image-caption' + + +class Trainers(object): + """ Names for different trainer. + + Holds the standard trainer name to use for identifying different trainer. + This should be used to register trainers. + + For a general Trainer, you can use easynlp-trainer/ofa-trainer/sofa-trainer. + For a model specific Trainer, you can use ${ModelName}-${Task}-trainer. + """ + + default = 'Trainer' + + +class Preprocessors(object): + """ Names for different preprocessor. + + Holds the standard preprocessor name to use for identifying different preprocessor. + This should be used to register preprocessors. + + For a general preprocessor, just use the function name as preprocessor name such as + resize-image, random-crop + For a model-specific preprocessor, use ${modelname}-${fuction} + """ + + # cv preprocessor + load_image = 'load-image' + + # nlp preprocessor + bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' + palm_text_gen_tokenizer = 'palm-text-gen-tokenizer' + sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' + + # audio preprocessor + linear_aec_fbank = 'linear-aec-fbank' + text_to_tacotron_symbols = 'text-to-tacotron-symbols' + + # multi-modal + ofa_image_caption = 'ofa-image-caption' diff --git a/modelscope/models/audio/tts/am/sambert_hifi_16k.py b/modelscope/models/audio/tts/am/sambert_hifi_16k.py index 2db9abc6..415e88b3 100644 --- a/modelscope/models/audio/tts/am/sambert_hifi_16k.py +++ b/modelscope/models/audio/tts/am/sambert_hifi_16k.py @@ -6,6 +6,7 @@ import numpy as np import tensorflow as tf from sklearn.preprocessing import MultiLabelBinarizer +from modelscope.metainfo import Models from modelscope.models.base import Model from modelscope.models.builder import MODELS from modelscope.utils.constant import ModelFile, Tasks @@ -26,7 +27,8 @@ def multi_label_symbol_to_sequence(my_classes, my_symbol): return one_hot.fit_transform(sequences) -@MODELS.register_module(Tasks.text_to_speech, module_name=r'sambert_hifi_16k') +@MODELS.register_module( + Tasks.text_to_speech, module_name=Models.sambert_hifi_16k) class SambertNetHifi16k(Model): def __init__(self, diff --git a/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py b/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py index c6aabf75..9f13f36f 100644 --- a/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py +++ b/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py @@ -2,6 +2,7 @@ import os import zipfile from typing import Any, Dict, List +from modelscope.metainfo import Models from modelscope.models.base import Model from modelscope.models.builder import MODELS from modelscope.utils.audio.tts_exceptions import ( @@ -13,7 +14,7 @@ __all__ = ['GenericTtsFrontend'] @MODELS.register_module( - Tasks.text_to_speech, module_name=r'generic_tts_frontend') + Tasks.text_to_speech, module_name=Models.generic_tts_frontend) class GenericTtsFrontend(Model): def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs): diff --git a/modelscope/models/audio/tts/vocoder/hifigan16k.py b/modelscope/models/audio/tts/vocoder/hifigan16k.py index 0d917dbe..b3fd9cf6 100644 --- a/modelscope/models/audio/tts/vocoder/hifigan16k.py +++ b/modelscope/models/audio/tts/vocoder/hifigan16k.py @@ -10,6 +10,7 @@ import numpy as np import torch from scipy.io.wavfile import write +from modelscope.metainfo import Models from modelscope.models.base import Model from modelscope.models.builder import MODELS from modelscope.utils.audio.tts_exceptions import \ @@ -36,7 +37,7 @@ class AttrDict(dict): self.__dict__ = self -@MODELS.register_module(Tasks.text_to_speech, module_name=r'hifigan16k') +@MODELS.register_module(Tasks.text_to_speech, module_name=Models.hifigan16k) class Hifigan16k(Model): def __init__(self, model_dir, *args, **kwargs): diff --git a/modelscope/models/base.py b/modelscope/models/base.py index 99309a7e..cb6d2b0e 100644 --- a/modelscope/models/base.py +++ b/modelscope/models/base.py @@ -8,6 +8,9 @@ from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.builder import build_model from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger() Tensor = Union['torch.Tensor', 'tf.Tensor'] @@ -46,18 +49,24 @@ class Model(ABC): local_model_dir = model_name_or_path else: local_model_dir = snapshot_download(model_name_or_path) - # else: - # raise ValueError( - # 'Remote model repo {model_name_or_path} does not exists') - + logger.info(f'initialize model from {local_model_dir}') cfg = Config.from_file( osp.join(local_model_dir, ModelFile.CONFIGURATION)) task_name = cfg.task model_cfg = cfg.model + assert hasattr( + cfg, 'pipeline'), 'pipeline config is missing from config file.' + pipeline_cfg = cfg.pipeline # TODO @wenmeng.zwm may should manually initialize model after model building if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): model_cfg.type = model_cfg.model_type + model_cfg.model_dir = local_model_dir + for k, v in kwargs.items(): model_cfg.k = v - return build_model(model_cfg, task_name) + model = build_model(model_cfg, task_name) + + # dynamically add pipeline info to model for pipeline inference + model.pipeline = pipeline_cfg + return model diff --git a/modelscope/models/multi_model/image_captioning_model.py b/modelscope/models/multi_model/image_captioning_model.py index fad0663e..79ab2b5f 100644 --- a/modelscope/models/multi_model/image_captioning_model.py +++ b/modelscope/models/multi_model/image_captioning_model.py @@ -3,6 +3,7 @@ from typing import Any, Dict from PIL import Image +from modelscope.metainfo import Models from modelscope.utils.constant import ModelFile, Tasks from ..base import Model from ..builder import MODELS @@ -10,8 +11,7 @@ from ..builder import MODELS __all__ = ['OfaForImageCaptioning'] -@MODELS.register_module( - Tasks.image_captioning, module_name=r'ofa-image-captioning') +@MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) class OfaForImageCaptioning(Model): def __init__(self, model_dir, *args, **kwargs): diff --git a/modelscope/models/nlp/bert_for_sequence_classification.py b/modelscope/models/nlp/bert_for_sequence_classification.py index a3cc4b68..7d85fa28 100644 --- a/modelscope/models/nlp/bert_for_sequence_classification.py +++ b/modelscope/models/nlp/bert_for_sequence_classification.py @@ -4,6 +4,7 @@ from typing import Any, Dict import json import numpy as np +from modelscope.metainfo import Models from modelscope.utils.constant import Tasks from ..base import Model from ..builder import MODELS @@ -11,8 +12,7 @@ from ..builder import MODELS __all__ = ['BertForSequenceClassification'] -@MODELS.register_module( - Tasks.text_classification, module_name=r'bert-sentiment-analysis') +@MODELS.register_module(Tasks.text_classification, module_name=Models.bert) class BertForSequenceClassification(Model): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/models/nlp/palm_for_text_generation.py b/modelscope/models/nlp/palm_for_text_generation.py index e5799feb..f4518d4f 100644 --- a/modelscope/models/nlp/palm_for_text_generation.py +++ b/modelscope/models/nlp/palm_for_text_generation.py @@ -1,5 +1,6 @@ from typing import Dict +from modelscope.metainfo import Models from modelscope.utils.constant import Tasks from ..base import Model, Tensor from ..builder import MODELS @@ -7,7 +8,7 @@ from ..builder import MODELS __all__ = ['PalmForTextGeneration'] -@MODELS.register_module(Tasks.text_generation, module_name=r'palm2.0') +@MODELS.register_module(Tasks.text_generation, module_name=Models.palm2_0) class PalmForTextGeneration(Model): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/models/nlp/sbert_for_sentence_similarity.py b/modelscope/models/nlp/sbert_for_sentence_similarity.py index 98daac92..cbcef1ce 100644 --- a/modelscope/models/nlp/sbert_for_sentence_similarity.py +++ b/modelscope/models/nlp/sbert_for_sentence_similarity.py @@ -8,6 +8,7 @@ from sofa import SbertModel from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel from torch import nn +from modelscope.metainfo import Models from modelscope.utils.constant import Tasks from ..base import Model, Tensor from ..builder import MODELS @@ -38,8 +39,7 @@ class SbertTextClassifier(SbertPreTrainedModel): @MODELS.register_module( - Tasks.sentence_similarity, - module_name=r'sbert-base-chinese-sentence-similarity') + Tasks.sentence_similarity, module_name=Models.structbert) class SbertForSentenceSimilarity(Model): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/models/nlp/sbert_for_token_classification.py b/modelscope/models/nlp/sbert_for_token_classification.py index b918dc37..fdf5afaf 100644 --- a/modelscope/models/nlp/sbert_for_token_classification.py +++ b/modelscope/models/nlp/sbert_for_token_classification.py @@ -4,6 +4,7 @@ import numpy as np import torch from sofa import SbertConfig, SbertForTokenClassification +from modelscope.metainfo import Models from modelscope.utils.constant import Tasks from ..base import Model, Tensor from ..builder import MODELS @@ -11,9 +12,7 @@ from ..builder import MODELS __all__ = ['StructBertForTokenClassification'] -@MODELS.register_module( - Tasks.word_segmentation, - module_name=r'structbert-chinese-word-segmentation') +@MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) class StructBertForTokenClassification(Model): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/pipelines/audio/linear_aec_pipeline.py b/modelscope/pipelines/audio/linear_aec_pipeline.py index 528d8d47..70562b19 100644 --- a/modelscope/pipelines/audio/linear_aec_pipeline.py +++ b/modelscope/pipelines/audio/linear_aec_pipeline.py @@ -7,6 +7,7 @@ import scipy.io.wavfile as wav import torch import yaml +from modelscope.metainfo import Pipelines from modelscope.preprocessors.audio import LinearAECAndFbank from modelscope.utils.constant import ModelFile, Tasks from ..base import Pipeline @@ -39,7 +40,8 @@ def initialize_config(module_cfg): @PIPELINES.register_module( - Tasks.speech_signal_process, module_name=r'speech_dfsmn_aec_psm_16k') + Tasks.speech_signal_process, + module_name=Pipelines.speech_dfsmn_aec_psm_16k) class LinearAECPipeline(Pipeline): r"""AEC Inference Pipeline only support 16000 sample rate. diff --git a/modelscope/pipelines/audio/text_to_speech_pipeline.py b/modelscope/pipelines/audio/text_to_speech_pipeline.py index ecd9daac..22586d3e 100644 --- a/modelscope/pipelines/audio/text_to_speech_pipeline.py +++ b/modelscope/pipelines/audio/text_to_speech_pipeline.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List import numpy as np +from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.models.audio.tts.am import SambertNetHifi16k from modelscope.models.audio.tts.vocoder import Hifigan16k @@ -15,7 +16,7 @@ __all__ = ['TextToSpeechSambertHifigan16kPipeline'] @PIPELINES.register_module( - Tasks.text_to_speech, module_name=r'tts-sambert-hifigan-16k') + Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_16k_tts) class TextToSpeechSambertHifigan16kPipeline(Pipeline): def __init__(self, diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 59bd298b..7e32f543 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -11,7 +11,7 @@ from modelscope.pydatasets import PyDataset from modelscope.utils.config import Config from modelscope.utils.logger import get_logger from .outputs import TASK_OUTPUTS -from .util import is_model_name +from .util import is_model, is_official_hub_path Tensor = Union['torch.Tensor', 'tf.Tensor'] Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] @@ -27,12 +27,10 @@ class Pipeline(ABC): def initiate_single_model(self, model): logger.info(f'initiate model from {model}') - # TODO @wenmeng.zwm replace model.startswith('damo/') with get_model - if isinstance(model, str) and model.startswith('damo/'): - if not osp.exists(model): - model = snapshot_download(model) - return Model.from_pretrained(model) if is_model_name( - model) else model + if isinstance(model, str) and is_official_hub_path(model): + model = snapshot_download( + model) if not osp.exists(model) else model + return Model.from_pretrained(model) if is_model(model) else model elif isinstance(model, Model): return model else: diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 5e1fbd87..90d613f8 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -3,32 +3,39 @@ import os.path as osp from typing import List, Union +from attr import has + +from modelscope.metainfo import Pipelines from modelscope.models.base import Model from modelscope.utils.config import Config, ConfigDict -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.hub import read_config from modelscope.utils.registry import Registry, build_from_cfg from .base import Pipeline +from .util import is_official_hub_path PIPELINES = Registry('pipelines') DEFAULT_MODEL_FOR_PIPELINE = { # TaskName: (pipeline_module_name, model_repo) Tasks.word_segmentation: - ('structbert-chinese-word-segmentation', + (Pipelines.word_segmentation, 'damo/nlp_structbert_word-segmentation_chinese-base'), Tasks.sentence_similarity: - ('sbert-base-chinese-sentence-similarity', + (Pipelines.sentence_similarity, 'damo/nlp_structbert_sentence-similarity_chinese-base'), - Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'), - Tasks.text_classification: - ('bert-sentiment-analysis', 'damo/bert-base-sst2'), - Tasks.text_generation: ('palm2.0', + Tasks.image_matting: + (Pipelines.image_matting, 'damo/cv_unet_image-matting'), + Tasks.text_classification: (Pipelines.sentiment_analysis, + 'damo/bert-base-sst2'), + Tasks.text_generation: (Pipelines.text_generation, 'damo/nlp_palm2.0_text-generation_chinese-base'), - Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'), + Tasks.image_captioning: (Pipelines.image_caption, + 'damo/ofa_image-caption_coco_large_en'), Tasks.image_generation: - ('person-image-cartoon', + (Pipelines.person_image_cartoon, 'damo/cv_unet_person-image-cartoon_compound-models'), - Tasks.ocr_detection: ('ocr-detection', + Tasks.ocr_detection: (Pipelines.ocr_detection, 'damo/cv_resnet18_ocr-detection-line-level_damo'), } @@ -86,30 +93,40 @@ def pipeline(task: str = None, if task is None and pipeline_name is None: raise ValueError('task or pipeline_name is required') + assert isinstance(model, (type(None), str, Model, list)), \ + f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' + if pipeline_name is None: # get default pipeline for this task if isinstance(model, str) \ or (isinstance(model, list) and isinstance(model[0], str)): - - # if is_model_name(model): - if (isinstance(model, str) and model.startswith('damo/')) \ - or (isinstance(model, list) and model[0].startswith('damo/')) \ - or (isinstance(model, str) and osp.exists(model)): - # TODO @wenmeng.zwm add support when model is a str of modelhub address - # read pipeline info from modelhub configuration file. - pipeline_name, default_model_repo = get_default_pipeline_info( - task) + if is_official_hub_path(model): + # read config file from hub and parse + cfg = read_config(model) if isinstance( + model, str) else read_config(model[0]) + assert hasattr( + cfg, + 'pipeline'), 'pipeline config is missing from config file.' + pipeline_name = cfg.pipeline.type else: + # used for test case, when model is str and is not hub path pipeline_name = get_pipeline_by_model_name(task, model) + elif isinstance(model, Model) or \ + (isinstance(model, list) and isinstance(model[0], Model)): + # get pipeline info from Model object + first_model = model[0] if isinstance(model, list) else model + if not hasattr(first_model, 'pipeline'): + # model is instantiated by user, we should parse config again + cfg = read_config(first_model.model_dir) + assert hasattr( + cfg, + 'pipeline'), 'pipeline config is missing from config file.' + first_model.pipeline = cfg.pipeline + pipeline_name = first_model.pipeline.type else: pipeline_name, default_model_repo = get_default_pipeline_info(task) - - if model is None: model = default_model_repo - assert isinstance(model, (type(None), str, Model, list)), \ - f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' - cfg = ConfigDict(type=pipeline_name, model=model) if kwargs: diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index d253eaf5..717336e9 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -6,6 +6,7 @@ import numpy as np import PIL import tensorflow as tf +from modelscope.metainfo import Pipelines from modelscope.models.cv.cartoon.facelib.facer import FaceAna from modelscope.models.cv.cartoon.mtcnn_pytorch.src.align_trans import ( get_reference_facial_points, warp_and_crop_face) @@ -25,7 +26,7 @@ logger = get_logger() @PIPELINES.register_module( - Tasks.image_generation, module_name='person-image-cartoon') + Tasks.image_generation, module_name=Pipelines.person_image_cartoon) class ImageCartoonPipeline(Pipeline): def __init__(self, model: str): diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index 0c60dfa7..b3e27e4b 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -5,6 +5,7 @@ import cv2 import numpy as np import PIL +from modelscope.metainfo import Pipelines from modelscope.pipelines.base import Input from modelscope.preprocessors import load_image from modelscope.utils.constant import ModelFile, Tasks @@ -16,7 +17,7 @@ logger = get_logger() @PIPELINES.register_module( - Tasks.image_matting, module_name=Tasks.image_matting) + Tasks.image_matting, module_name=Pipelines.image_matting) class ImageMattingPipeline(Pipeline): def __init__(self, model: str): diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index 9728e441..0502fe36 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -10,6 +10,7 @@ import PIL import tensorflow as tf import tf_slim as slim +from modelscope.metainfo import Pipelines from modelscope.pipelines.base import Input from modelscope.preprocessors import load_image from modelscope.utils.constant import ModelFile, Tasks @@ -38,7 +39,7 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6, @PIPELINES.register_module( - Tasks.ocr_detection, module_name=Tasks.ocr_detection) + Tasks.ocr_detection, module_name=Pipelines.ocr_detection) class OCRDetectionPipeline(Pipeline): def __init__(self, model: str): diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py index f0b1f53c..9f32caf4 100644 --- a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py +++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py @@ -1,5 +1,6 @@ from typing import Any, Dict, Union +from modelscope.metainfo import Pipelines from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -9,7 +10,8 @@ from ..builder import PIPELINES logger = get_logger() -@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa') +@PIPELINES.register_module( + Tasks.image_captioning, module_name=Pipelines.image_caption) class ImageCaptionPipeline(Pipeline): def __init__(self, diff --git a/modelscope/pipelines/nlp/sentence_similarity_pipeline.py b/modelscope/pipelines/nlp/sentence_similarity_pipeline.py index 1b630c10..71df86e2 100644 --- a/modelscope/pipelines/nlp/sentence_similarity_pipeline.py +++ b/modelscope/pipelines/nlp/sentence_similarity_pipeline.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Union import numpy as np +from modelscope.metainfo import Pipelines from modelscope.models.nlp import SbertForSentenceSimilarity from modelscope.preprocessors import SequenceClassificationPreprocessor from modelscope.utils.constant import Tasks @@ -13,8 +14,7 @@ __all__ = ['SentenceSimilarityPipeline'] @PIPELINES.register_module( - Tasks.sentence_similarity, - module_name=r'sbert-base-chinese-sentence-similarity') + Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity) class SentenceSimilarityPipeline(Pipeline): def __init__(self, diff --git a/modelscope/pipelines/nlp/sequence_classification_pipeline.py b/modelscope/pipelines/nlp/sequence_classification_pipeline.py index 1dbe2efd..43c81d60 100644 --- a/modelscope/pipelines/nlp/sequence_classification_pipeline.py +++ b/modelscope/pipelines/nlp/sequence_classification_pipeline.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Union import numpy as np +from modelscope.metainfo import Pipelines from modelscope.models.nlp import BertForSequenceClassification from modelscope.preprocessors import SequenceClassificationPreprocessor from modelscope.utils.constant import Tasks @@ -13,7 +14,7 @@ __all__ = ['SequenceClassificationPipeline'] @PIPELINES.register_module( - Tasks.text_classification, module_name=r'bert-sentiment-analysis') + Tasks.text_classification, module_name=Pipelines.sentiment_analysis) class SequenceClassificationPipeline(Pipeline): def __init__(self, diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index 881e7ea6..ebd4be8e 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -1,5 +1,6 @@ from typing import Dict, Optional, Union +from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.models.nlp import PalmForTextGeneration from modelscope.preprocessors import TextGenerationPreprocessor @@ -10,7 +11,8 @@ from ..builder import PIPELINES __all__ = ['TextGenerationPipeline'] -@PIPELINES.register_module(Tasks.text_generation, module_name=r'palm2.0') +@PIPELINES.register_module( + Tasks.text_generation, module_name=Pipelines.text_generation) class TextGenerationPipeline(Pipeline): def __init__(self, diff --git a/modelscope/pipelines/nlp/word_segmentation_pipeline.py b/modelscope/pipelines/nlp/word_segmentation_pipeline.py index 1cc08a38..a45dafc3 100644 --- a/modelscope/pipelines/nlp/word_segmentation_pipeline.py +++ b/modelscope/pipelines/nlp/word_segmentation_pipeline.py @@ -1,5 +1,6 @@ from typing import Any, Dict, Optional, Union +from modelscope.metainfo import Pipelines from modelscope.models import Model from modelscope.models.nlp import StructBertForTokenClassification from modelscope.preprocessors import TokenClassifcationPreprocessor @@ -11,8 +12,7 @@ __all__ = ['WordSegmentationPipeline'] @PIPELINES.register_module( - Tasks.word_segmentation, - module_name=r'structbert-chinese-word-segmentation') + Tasks.word_segmentation, module_name=Pipelines.word_segmentation) class WordSegmentationPipeline(Pipeline): def __init__(self, diff --git a/modelscope/pipelines/util.py b/modelscope/pipelines/util.py index 6fe6e9fd..d034a7d4 100644 --- a/modelscope/pipelines/util.py +++ b/modelscope/pipelines/util.py @@ -2,6 +2,7 @@ import os.path as osp from typing import List, Union +from modelscope.hub.api import HubApi from modelscope.hub.file_download import model_file_download from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile @@ -19,31 +20,63 @@ def is_config_has_model(cfg_file): return False -def is_model_name(model: Union[str, List]): - """ whether model is a valid modelhub path +def is_official_hub_path(path: Union[str, List]): + """ Whether path is a official hub name or a valid local + path to official hub directory. """ - def is_model_name_impl(model): - if osp.exists(model): - cfg_file = osp.join(model, ModelFile.CONFIGURATION) + def is_official_hub_impl(path): + if osp.exists(path): + cfg_file = osp.join(path, ModelFile.CONFIGURATION) + return osp.exists(cfg_file) + else: + try: + _ = HubApi().get_model(path) + return True + except Exception: + return False + + if isinstance(path, str): + return is_official_hub_impl(path) + else: + results = [is_official_hub_impl(m) for m in path] + all_true = all(results) + any_true = any(results) + if any_true and not all_true: + raise ValueError( + f'some model are hub address, some are not, model list: {path}' + ) + + return all_true + + +def is_model(path: Union[str, List]): + """ whether path is a valid modelhub path and containing model config + """ + + def is_modelhub_path_impl(path): + if osp.exists(path): + cfg_file = osp.join(path, ModelFile.CONFIGURATION) if osp.exists(cfg_file): return is_config_has_model(cfg_file) else: return False else: try: - cfg_file = model_file_download(model, ModelFile.CONFIGURATION) + cfg_file = model_file_download(path, ModelFile.CONFIGURATION) return is_config_has_model(cfg_file) except Exception: return False - if isinstance(model, str): - return is_model_name_impl(model) + if isinstance(path, str): + return is_modelhub_path_impl(path) else: - results = [is_model_name_impl(m) for m in model] + results = [is_modelhub_path_impl(m) for m in path] all_true = all(results) any_true = any(results) if any_true and not all_true: - raise ValueError('some model are hub address, some are not') + raise ValueError( + f'some models are hub address, some are not, model list: {path}' + ) return all_true diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py index 6bd8aed5..b2123fb7 100644 --- a/modelscope/preprocessors/image.py +++ b/modelscope/preprocessors/image.py @@ -5,11 +5,12 @@ from typing import Dict, Union from PIL import Image, ImageOps from modelscope.fileio import File +from modelscope.metainfo import Preprocessors from modelscope.utils.constant import Fields from .builder import PREPROCESSORS -@PREPROCESSORS.register_module(Fields.cv) +@PREPROCESSORS.register_module(Fields.cv, Preprocessors.load_image) class LoadImage: """Load an image from file or url. Added or updated keys are "filename", "img", "img_shape", diff --git a/modelscope/preprocessors/multi_model.py b/modelscope/preprocessors/multi_model.py index ea2e7493..aa0bc8a7 100644 --- a/modelscope/preprocessors/multi_model.py +++ b/modelscope/preprocessors/multi_model.py @@ -7,6 +7,7 @@ import torch from PIL import Image from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Preprocessors from modelscope.utils.constant import Fields, ModelFile from modelscope.utils.type_assert import type_assert from .base import Preprocessor @@ -19,7 +20,7 @@ __all__ = [ @PREPROCESSORS.register_module( - Fields.multi_modal, module_name=r'ofa-image-caption') + Fields.multi_modal, module_name=Preprocessors.ofa_image_caption) class OfaImageCaptionPreprocessor(Preprocessor): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py index 0abb01cc..7a47a866 100644 --- a/modelscope/preprocessors/nlp.py +++ b/modelscope/preprocessors/nlp.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Union from transformers import AutoTokenizer +from modelscope.metainfo import Preprocessors from modelscope.utils.constant import Fields, InputFields from modelscope.utils.type_assert import type_assert from .base import Preprocessor @@ -31,7 +32,7 @@ class Tokenize(Preprocessor): @PREPROCESSORS.register_module( - Fields.nlp, module_name=r'bert-sequence-classification') + Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) class SequenceClassificationPreprocessor(Preprocessor): def __init__(self, model_dir: str, *args, **kwargs): @@ -124,7 +125,8 @@ class SequenceClassificationPreprocessor(Preprocessor): return rst -@PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm2.0') +@PREPROCESSORS.register_module( + Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer) class TextGenerationPreprocessor(Preprocessor): def __init__(self, model_dir: str, tokenizer, *args, **kwargs): @@ -180,7 +182,7 @@ class TextGenerationPreprocessor(Preprocessor): @PREPROCESSORS.register_module( - Fields.nlp, module_name=r'bert-token-classification') + Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer) class TokenClassifcationPreprocessor(Preprocessor): def __init__(self, model_dir: str, *args, **kwargs): diff --git a/modelscope/preprocessors/text_to_speech.py b/modelscope/preprocessors/text_to_speech.py index 8b8dae14..9d8af6fa 100644 --- a/modelscope/preprocessors/text_to_speech.py +++ b/modelscope/preprocessors/text_to_speech.py @@ -3,6 +3,7 @@ import io from typing import Any, Dict, Union from modelscope.fileio import File +from modelscope.metainfo import Preprocessors from modelscope.models.audio.tts.frontend import GenericTtsFrontend from modelscope.models.base import Model from modelscope.utils.audio.tts_exceptions import * # noqa F403 @@ -10,11 +11,11 @@ from modelscope.utils.constant import Fields from .base import Preprocessor from .builder import PREPROCESSORS -__all__ = ['TextToTacotronSymbols', 'text_to_tacotron_symbols'] +__all__ = ['TextToTacotronSymbols'] @PREPROCESSORS.register_module( - Fields.audio, module_name=r'text_to_tacotron_symbols') + Fields.audio, module_name=Preprocessors.text_to_tacotron_symbols) class TextToTacotronSymbols(Preprocessor): """extract tacotron symbols from text. diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index 245642d1..01a1b1b0 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -1,11 +1,49 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import os.path as osp +from typing import List, Union -from modelscope.hub.constants import MODEL_ID_SEPARATOR +from numpy import deprecate + +from modelscope.hub.file_download import model_file_download +from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.utils.utils import get_cache_dir +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile # temp solution before the hub-cache is in place +@deprecate def get_model_cache_dir(model_id: str): return os.path.join(get_cache_dir(), model_id) + + +def read_config(model_id_or_path: str): + """ Read config from hub or local path + + Args: + model_id_or_path (str): Model repo name or local directory path. + + Return: + config (:obj:`Config`): config object + """ + if not os.path.exists(model_id_or_path): + local_path = model_file_download(model_id_or_path, + ModelFile.CONFIGURATION) + else: + local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION) + + return Config.from_file(local_path) + + +def auto_load(model: Union[str, List[str]]): + if isinstance(model, str): + if not osp.exists(model): + model = snapshot_download(model) + else: + model = [ + snapshot_download(m) if not osp.exists(m) else m for m in model + ] + + return model diff --git a/requirements/audio.txt b/requirements/audio.txt index 140836a8..3b625261 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -1,10 +1,10 @@ #tts h5py==2.10.0 -#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl -https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl +https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl; python_version=='3.6' +https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl; python_version=='3.7' +https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl; python_version=='3.8' +https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl; python_version=='3.9' https://swap.oss-cn-hangzhou.aliyuncs.com/Jiaqi%2Fmaas%2Ftts%2Frequirements%2Fpytorch_wavelets-1.3.0-py3-none-any.whl?Expires=1685688388&OSSAccessKeyId=LTAI4Ffebq4d9jTVDwiSbY4L&Signature=jcQbg5EZ%2Bdys3%2F4BRn3srrKLdIg%3D -#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl -#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl inflect keras==2.2.4 librosa @@ -12,7 +12,7 @@ lxml matplotlib nara_wpe numpy==1.18.* -protobuf==3.20.* +protobuf>3,<=3.20 ptflops PyWavelets>=1.0.0 scikit-learn==0.23.2 diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 751b6975..23ea678b 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -60,7 +60,7 @@ class ImageMattingTest(unittest.TestCase): cv2.imwrite('result.png', result['output_png']) print(f'Output written to {osp.abspath("result.png")}') - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_modelscope_dataset(self): dataset = PyDataset.load('beans', split='train', target='image') img_matting = pipeline(Tasks.image_matting, model=self.model_id) diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index f1369a2f..23939f8e 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -3,6 +3,7 @@ import shutil import unittest from modelscope.fileio import File +from modelscope.metainfo import Pipelines from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks @@ -42,7 +43,7 @@ class SpeechSignalProcessTest(unittest.TestCase): aec = pipeline( Tasks.speech_signal_process, model=self.model_id, - pipeline_name=r'speech_dfsmn_aec_psm_16k') + pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k) aec(input, output_path='output.wav') diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py index 8ecd9ed4..2581c220 100644 --- a/tests/pipelines/test_text_classification.py +++ b/tests/pipelines/test_text_classification.py @@ -38,31 +38,6 @@ class SequenceClassificationTest(unittest.TestCase): break print(r) - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') - def test_run(self): - model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ - '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' - cache_path_str = r'.cache/easynlp/bert-base-sst2.zip' - cache_path = Path(cache_path_str) - - if not cache_path.exists(): - cache_path.parent.mkdir(parents=True, exist_ok=True) - cache_path.touch(exist_ok=True) - with cache_path.open('wb') as ofile: - ofile.write(File.read(model_url)) - - with zipfile.ZipFile(cache_path_str, 'r') as zipf: - zipf.extractall(cache_path.parent) - path = r'.cache/easynlp/' - model = BertForSequenceClassification(path) - preprocessor = SequenceClassificationPreprocessor( - path, first_sequence='sentence', second_sequence=None) - pipeline1 = SequenceClassificationPipeline(model, preprocessor) - self.predict(pipeline1) - pipeline2 = pipeline( - Tasks.text_classification, model=model, preprocessor=preprocessor) - print(pipeline2('Hello world!')) - @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_from_modelhub(self): model = Model.from_pretrained(self.model_id) diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py index c9b988a1..0d76cbac 100644 --- a/tests/pipelines/test_text_to_speech.py +++ b/tests/pipelines/test_text_to_speech.py @@ -11,6 +11,7 @@ import torch from scipy.io.wavfile import write from modelscope.fileio import File +from modelscope.metainfo import Pipelines, Preprocessors from modelscope.models import Model, build_model from modelscope.models.audio.tts.am import SambertNetHifi16k from modelscope.models.audio.tts.vocoder import AttrDict, Hifigan16k @@ -32,7 +33,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo' cfg_preprocessor = dict( - type='text_to_tacotron_symbols', + type=Preprocessors.text_to_tacotron_symbols, model_name=preprocessor_model_id, lang_type=lang_type) preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) @@ -45,7 +46,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): self.assertTrue(voc is not None) sambert_tts = pipeline( - pipeline_name='tts-sambert-hifigan-16k', + pipeline_name=Pipelines.sambert_hifigan_16k_tts, config_file='', model=[am, voc], preprocessor=preprocessor) diff --git a/tests/preprocessors/test_text_to_speech.py b/tests/preprocessors/test_text_to_speech.py index 18b66987..fd2473fd 100644 --- a/tests/preprocessors/test_text_to_speech.py +++ b/tests/preprocessors/test_text_to_speech.py @@ -1,6 +1,7 @@ import shutil import unittest +from modelscope.metainfo import Preprocessors from modelscope.preprocessors import build_preprocessor from modelscope.utils.constant import Fields, InputFields from modelscope.utils.logger import get_logger @@ -14,7 +15,7 @@ class TtsPreprocessorTest(unittest.TestCase): lang_type = 'pinyin' text = '今天天气不错,我们去散步吧。' cfg = dict( - type='text_to_tacotron_symbols', + type=Preprocessors.text_to_tacotron_symbols, model_name='damo/speech_binary_tts_frontend_resource', lang_type=lang_type) preprocessor = build_preprocessor(cfg, Fields.audio) diff --git a/tests/pydatasets/test_py_dataset.py b/tests/pydatasets/test_py_dataset.py index 4ad767fa..bc38e369 100644 --- a/tests/pydatasets/test_py_dataset.py +++ b/tests/pydatasets/test_py_dataset.py @@ -33,6 +33,8 @@ class ImgPreprocessor(Preprocessor): class PyDatasetTest(unittest.TestCase): + @unittest.skipUnless(test_level() >= 2, + 'skip test due to dataset api problem') def test_ds_basic(self): ms_ds_full = PyDataset.load('squad') ms_ds_full_hf = hfdata.load_dataset('squad')