* using get_model to validate hub path * support reading pipeline info from configuration file * add metainfo const * update model type and pipeline type and fix UT * relax requimrent for protobuf * skip two dataset tests due to temporal failure Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9118154master
| @@ -0,0 +1,94 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| class Models(object): | |||
| """ Names for different models. | |||
| Holds the standard model name to use for identifying different model. | |||
| This should be used to register models. | |||
| Model name should only contain model info but not task info. | |||
| """ | |||
| # vision models | |||
| # nlp models | |||
| bert = 'bert' | |||
| palm2_0 = 'palm2.0' | |||
| structbert = 'structbert' | |||
| # audio models | |||
| sambert_hifi_16k = 'sambert-hifi-16k' | |||
| generic_tts_frontend = 'generic-tts-frontend' | |||
| hifigan16k = 'hifigan16k' | |||
| # multi-modal models | |||
| ofa = 'ofa' | |||
| class Pipelines(object): | |||
| """ Names for different pipelines. | |||
| Holds the standard pipline name to use for identifying different pipeline. | |||
| This should be used to register pipelines. | |||
| For pipeline which support different models and implements the common function, we | |||
| should use task name for this pipeline. | |||
| For pipeline which suuport only one model, we should use ${Model}-${Task} as its name. | |||
| """ | |||
| # vision tasks | |||
| image_matting = 'unet-image-matting' | |||
| person_image_cartoon = 'unet-person-image-cartoon' | |||
| ocr_detection = 'resnet18-ocr-detection' | |||
| # nlp tasks | |||
| sentence_similarity = 'sentence-similarity' | |||
| word_segmentation = 'word-segmentation' | |||
| text_generation = 'text-generation' | |||
| sentiment_analysis = 'sentiment-analysis' | |||
| # audio tasks | |||
| sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' | |||
| speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' | |||
| # multi-modal tasks | |||
| image_caption = 'image-caption' | |||
| class Trainers(object): | |||
| """ Names for different trainer. | |||
| Holds the standard trainer name to use for identifying different trainer. | |||
| This should be used to register trainers. | |||
| For a general Trainer, you can use easynlp-trainer/ofa-trainer/sofa-trainer. | |||
| For a model specific Trainer, you can use ${ModelName}-${Task}-trainer. | |||
| """ | |||
| default = 'Trainer' | |||
| class Preprocessors(object): | |||
| """ Names for different preprocessor. | |||
| Holds the standard preprocessor name to use for identifying different preprocessor. | |||
| This should be used to register preprocessors. | |||
| For a general preprocessor, just use the function name as preprocessor name such as | |||
| resize-image, random-crop | |||
| For a model-specific preprocessor, use ${modelname}-${fuction} | |||
| """ | |||
| # cv preprocessor | |||
| load_image = 'load-image' | |||
| # nlp preprocessor | |||
| bert_seq_cls_tokenizer = 'bert-seq-cls-tokenizer' | |||
| palm_text_gen_tokenizer = 'palm-text-gen-tokenizer' | |||
| sbert_token_cls_tokenizer = 'sbert-token-cls-tokenizer' | |||
| # audio preprocessor | |||
| linear_aec_fbank = 'linear-aec-fbank' | |||
| text_to_tacotron_symbols = 'text-to-tacotron-symbols' | |||
| # multi-modal | |||
| ofa_image_caption = 'ofa-image-caption' | |||
| @@ -6,6 +6,7 @@ import numpy as np | |||
| import tensorflow as tf | |||
| from sklearn.preprocessing import MultiLabelBinarizer | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Model | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| @@ -26,7 +27,8 @@ def multi_label_symbol_to_sequence(my_classes, my_symbol): | |||
| return one_hot.fit_transform(sequences) | |||
| @MODELS.register_module(Tasks.text_to_speech, module_name=r'sambert_hifi_16k') | |||
| @MODELS.register_module( | |||
| Tasks.text_to_speech, module_name=Models.sambert_hifi_16k) | |||
| class SambertNetHifi16k(Model): | |||
| def __init__(self, | |||
| @@ -2,6 +2,7 @@ import os | |||
| import zipfile | |||
| from typing import Any, Dict, List | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Model | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.audio.tts_exceptions import ( | |||
| @@ -13,7 +14,7 @@ __all__ = ['GenericTtsFrontend'] | |||
| @MODELS.register_module( | |||
| Tasks.text_to_speech, module_name=r'generic_tts_frontend') | |||
| Tasks.text_to_speech, module_name=Models.generic_tts_frontend) | |||
| class GenericTtsFrontend(Model): | |||
| def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs): | |||
| @@ -10,6 +10,7 @@ import numpy as np | |||
| import torch | |||
| from scipy.io.wavfile import write | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Model | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.audio.tts_exceptions import \ | |||
| @@ -36,7 +37,7 @@ class AttrDict(dict): | |||
| self.__dict__ = self | |||
| @MODELS.register_module(Tasks.text_to_speech, module_name=r'hifigan16k') | |||
| @MODELS.register_module(Tasks.text_to_speech, module_name=Models.hifigan16k) | |||
| class Hifigan16k(Model): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| @@ -8,6 +8,9 @@ from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.builder import build_model | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| @@ -46,18 +49,24 @@ class Model(ABC): | |||
| local_model_dir = model_name_or_path | |||
| else: | |||
| local_model_dir = snapshot_download(model_name_or_path) | |||
| # else: | |||
| # raise ValueError( | |||
| # 'Remote model repo {model_name_or_path} does not exists') | |||
| logger.info(f'initialize model from {local_model_dir}') | |||
| cfg = Config.from_file( | |||
| osp.join(local_model_dir, ModelFile.CONFIGURATION)) | |||
| task_name = cfg.task | |||
| model_cfg = cfg.model | |||
| assert hasattr( | |||
| cfg, 'pipeline'), 'pipeline config is missing from config file.' | |||
| pipeline_cfg = cfg.pipeline | |||
| # TODO @wenmeng.zwm may should manually initialize model after model building | |||
| if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | |||
| model_cfg.type = model_cfg.model_type | |||
| model_cfg.model_dir = local_model_dir | |||
| for k, v in kwargs.items(): | |||
| model_cfg.k = v | |||
| return build_model(model_cfg, task_name) | |||
| model = build_model(model_cfg, task_name) | |||
| # dynamically add pipeline info to model for pipeline inference | |||
| model.pipeline = pipeline_cfg | |||
| return model | |||
| @@ -3,6 +3,7 @@ from typing import Any, Dict | |||
| from PIL import Image | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from ..base import Model | |||
| from ..builder import MODELS | |||
| @@ -10,8 +11,7 @@ from ..builder import MODELS | |||
| __all__ = ['OfaForImageCaptioning'] | |||
| @MODELS.register_module( | |||
| Tasks.image_captioning, module_name=r'ofa-image-captioning') | |||
| @MODELS.register_module(Tasks.image_captioning, module_name=Models.ofa) | |||
| class OfaForImageCaptioning(Model): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| @@ -4,6 +4,7 @@ from typing import Any, Dict | |||
| import json | |||
| import numpy as np | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Model | |||
| from ..builder import MODELS | |||
| @@ -11,8 +12,7 @@ from ..builder import MODELS | |||
| __all__ = ['BertForSequenceClassification'] | |||
| @MODELS.register_module( | |||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||
| @MODELS.register_module(Tasks.text_classification, module_name=Models.bert) | |||
| class BertForSequenceClassification(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -1,5 +1,6 @@ | |||
| from typing import Dict | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Model, Tensor | |||
| from ..builder import MODELS | |||
| @@ -7,7 +8,7 @@ from ..builder import MODELS | |||
| __all__ = ['PalmForTextGeneration'] | |||
| @MODELS.register_module(Tasks.text_generation, module_name=r'palm2.0') | |||
| @MODELS.register_module(Tasks.text_generation, module_name=Models.palm2_0) | |||
| class PalmForTextGeneration(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -8,6 +8,7 @@ from sofa import SbertModel | |||
| from sofa.models.sbert.modeling_sbert import SbertPreTrainedModel | |||
| from torch import nn | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Model, Tensor | |||
| from ..builder import MODELS | |||
| @@ -38,8 +39,7 @@ class SbertTextClassifier(SbertPreTrainedModel): | |||
| @MODELS.register_module( | |||
| Tasks.sentence_similarity, | |||
| module_name=r'sbert-base-chinese-sentence-similarity') | |||
| Tasks.sentence_similarity, module_name=Models.structbert) | |||
| class SbertForSentenceSimilarity(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -4,6 +4,7 @@ import numpy as np | |||
| import torch | |||
| from sofa import SbertConfig, SbertForTokenClassification | |||
| from modelscope.metainfo import Models | |||
| from modelscope.utils.constant import Tasks | |||
| from ..base import Model, Tensor | |||
| from ..builder import MODELS | |||
| @@ -11,9 +12,7 @@ from ..builder import MODELS | |||
| __all__ = ['StructBertForTokenClassification'] | |||
| @MODELS.register_module( | |||
| Tasks.word_segmentation, | |||
| module_name=r'structbert-chinese-word-segmentation') | |||
| @MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert) | |||
| class StructBertForTokenClassification(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -7,6 +7,7 @@ import scipy.io.wavfile as wav | |||
| import torch | |||
| import yaml | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.preprocessors.audio import LinearAECAndFbank | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from ..base import Pipeline | |||
| @@ -39,7 +40,8 @@ def initialize_config(module_cfg): | |||
| @PIPELINES.register_module( | |||
| Tasks.speech_signal_process, module_name=r'speech_dfsmn_aec_psm_16k') | |||
| Tasks.speech_signal_process, | |||
| module_name=Pipelines.speech_dfsmn_aec_psm_16k) | |||
| class LinearAECPipeline(Pipeline): | |||
| r"""AEC Inference Pipeline only support 16000 sample rate. | |||
| @@ -3,6 +3,7 @@ from typing import Any, Dict, List | |||
| import numpy as np | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.models.audio.tts.am import SambertNetHifi16k | |||
| from modelscope.models.audio.tts.vocoder import Hifigan16k | |||
| @@ -15,7 +16,7 @@ __all__ = ['TextToSpeechSambertHifigan16kPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.text_to_speech, module_name=r'tts-sambert-hifigan-16k') | |||
| Tasks.text_to_speech, module_name=Pipelines.sambert_hifigan_16k_tts) | |||
| class TextToSpeechSambertHifigan16kPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -11,7 +11,7 @@ from modelscope.pydatasets import PyDataset | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.logger import get_logger | |||
| from .outputs import TASK_OUTPUTS | |||
| from .util import is_model_name | |||
| from .util import is_model, is_official_hub_path | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| Input = Union[str, tuple, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | |||
| @@ -27,12 +27,10 @@ class Pipeline(ABC): | |||
| def initiate_single_model(self, model): | |||
| logger.info(f'initiate model from {model}') | |||
| # TODO @wenmeng.zwm replace model.startswith('damo/') with get_model | |||
| if isinstance(model, str) and model.startswith('damo/'): | |||
| if not osp.exists(model): | |||
| model = snapshot_download(model) | |||
| return Model.from_pretrained(model) if is_model_name( | |||
| model) else model | |||
| if isinstance(model, str) and is_official_hub_path(model): | |||
| model = snapshot_download( | |||
| model) if not osp.exists(model) else model | |||
| return Model.from_pretrained(model) if is_model(model) else model | |||
| elif isinstance(model, Model): | |||
| return model | |||
| else: | |||
| @@ -3,32 +3,39 @@ | |||
| import os.path as osp | |||
| from typing import List, Union | |||
| from attr import has | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.base import Model | |||
| from modelscope.utils.config import Config, ConfigDict | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.hub import read_config | |||
| from modelscope.utils.registry import Registry, build_from_cfg | |||
| from .base import Pipeline | |||
| from .util import is_official_hub_path | |||
| PIPELINES = Registry('pipelines') | |||
| DEFAULT_MODEL_FOR_PIPELINE = { | |||
| # TaskName: (pipeline_module_name, model_repo) | |||
| Tasks.word_segmentation: | |||
| ('structbert-chinese-word-segmentation', | |||
| (Pipelines.word_segmentation, | |||
| 'damo/nlp_structbert_word-segmentation_chinese-base'), | |||
| Tasks.sentence_similarity: | |||
| ('sbert-base-chinese-sentence-similarity', | |||
| (Pipelines.sentence_similarity, | |||
| 'damo/nlp_structbert_sentence-similarity_chinese-base'), | |||
| Tasks.image_matting: ('image-matting', 'damo/cv_unet_image-matting'), | |||
| Tasks.text_classification: | |||
| ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | |||
| Tasks.text_generation: ('palm2.0', | |||
| Tasks.image_matting: | |||
| (Pipelines.image_matting, 'damo/cv_unet_image-matting'), | |||
| Tasks.text_classification: (Pipelines.sentiment_analysis, | |||
| 'damo/bert-base-sst2'), | |||
| Tasks.text_generation: (Pipelines.text_generation, | |||
| 'damo/nlp_palm2.0_text-generation_chinese-base'), | |||
| Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'), | |||
| Tasks.image_captioning: (Pipelines.image_caption, | |||
| 'damo/ofa_image-caption_coco_large_en'), | |||
| Tasks.image_generation: | |||
| ('person-image-cartoon', | |||
| (Pipelines.person_image_cartoon, | |||
| 'damo/cv_unet_person-image-cartoon_compound-models'), | |||
| Tasks.ocr_detection: ('ocr-detection', | |||
| Tasks.ocr_detection: (Pipelines.ocr_detection, | |||
| 'damo/cv_resnet18_ocr-detection-line-level_damo'), | |||
| } | |||
| @@ -86,30 +93,40 @@ def pipeline(task: str = None, | |||
| if task is None and pipeline_name is None: | |||
| raise ValueError('task or pipeline_name is required') | |||
| assert isinstance(model, (type(None), str, Model, list)), \ | |||
| f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' | |||
| if pipeline_name is None: | |||
| # get default pipeline for this task | |||
| if isinstance(model, str) \ | |||
| or (isinstance(model, list) and isinstance(model[0], str)): | |||
| # if is_model_name(model): | |||
| if (isinstance(model, str) and model.startswith('damo/')) \ | |||
| or (isinstance(model, list) and model[0].startswith('damo/')) \ | |||
| or (isinstance(model, str) and osp.exists(model)): | |||
| # TODO @wenmeng.zwm add support when model is a str of modelhub address | |||
| # read pipeline info from modelhub configuration file. | |||
| pipeline_name, default_model_repo = get_default_pipeline_info( | |||
| task) | |||
| if is_official_hub_path(model): | |||
| # read config file from hub and parse | |||
| cfg = read_config(model) if isinstance( | |||
| model, str) else read_config(model[0]) | |||
| assert hasattr( | |||
| cfg, | |||
| 'pipeline'), 'pipeline config is missing from config file.' | |||
| pipeline_name = cfg.pipeline.type | |||
| else: | |||
| # used for test case, when model is str and is not hub path | |||
| pipeline_name = get_pipeline_by_model_name(task, model) | |||
| elif isinstance(model, Model) or \ | |||
| (isinstance(model, list) and isinstance(model[0], Model)): | |||
| # get pipeline info from Model object | |||
| first_model = model[0] if isinstance(model, list) else model | |||
| if not hasattr(first_model, 'pipeline'): | |||
| # model is instantiated by user, we should parse config again | |||
| cfg = read_config(first_model.model_dir) | |||
| assert hasattr( | |||
| cfg, | |||
| 'pipeline'), 'pipeline config is missing from config file.' | |||
| first_model.pipeline = cfg.pipeline | |||
| pipeline_name = first_model.pipeline.type | |||
| else: | |||
| pipeline_name, default_model_repo = get_default_pipeline_info(task) | |||
| if model is None: | |||
| model = default_model_repo | |||
| assert isinstance(model, (type(None), str, Model, list)), \ | |||
| f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' | |||
| cfg = ConfigDict(type=pipeline_name, model=model) | |||
| if kwargs: | |||
| @@ -6,6 +6,7 @@ import numpy as np | |||
| import PIL | |||
| import tensorflow as tf | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.cartoon.facelib.facer import FaceAna | |||
| from modelscope.models.cv.cartoon.mtcnn_pytorch.src.align_trans import ( | |||
| get_reference_facial_points, warp_and_crop_face) | |||
| @@ -25,7 +26,7 @@ logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_generation, module_name='person-image-cartoon') | |||
| Tasks.image_generation, module_name=Pipelines.person_image_cartoon) | |||
| class ImageCartoonPipeline(Pipeline): | |||
| def __init__(self, model: str): | |||
| @@ -5,6 +5,7 @@ import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| @@ -16,7 +17,7 @@ logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_matting, module_name=Tasks.image_matting) | |||
| Tasks.image_matting, module_name=Pipelines.image_matting) | |||
| class ImageMattingPipeline(Pipeline): | |||
| def __init__(self, model: str): | |||
| @@ -10,6 +10,7 @@ import PIL | |||
| import tensorflow as tf | |||
| import tf_slim as slim | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| @@ -38,7 +39,7 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6, | |||
| @PIPELINES.register_module( | |||
| Tasks.ocr_detection, module_name=Tasks.ocr_detection) | |||
| Tasks.ocr_detection, module_name=Pipelines.ocr_detection) | |||
| class OCRDetectionPipeline(Pipeline): | |||
| def __init__(self, model: str): | |||
| @@ -1,5 +1,6 @@ | |||
| from typing import Any, Dict, Union | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -9,7 +10,8 @@ from ..builder import PIPELINES | |||
| logger = get_logger() | |||
| @PIPELINES.register_module(Tasks.image_captioning, module_name='ofa') | |||
| @PIPELINES.register_module( | |||
| Tasks.image_captioning, module_name=Pipelines.image_caption) | |||
| class ImageCaptionPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -2,6 +2,7 @@ from typing import Any, Dict, Union | |||
| import numpy as np | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.nlp import SbertForSentenceSimilarity | |||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| @@ -13,8 +14,7 @@ __all__ = ['SentenceSimilarityPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.sentence_similarity, | |||
| module_name=r'sbert-base-chinese-sentence-similarity') | |||
| Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity) | |||
| class SentenceSimilarityPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -2,6 +2,7 @@ from typing import Any, Dict, Union | |||
| import numpy as np | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.nlp import BertForSequenceClassification | |||
| from modelscope.preprocessors import SequenceClassificationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| @@ -13,7 +14,7 @@ __all__ = ['SequenceClassificationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||
| Tasks.text_classification, module_name=Pipelines.sentiment_analysis) | |||
| class SequenceClassificationPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -1,5 +1,6 @@ | |||
| from typing import Dict, Optional, Union | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import PalmForTextGeneration | |||
| from modelscope.preprocessors import TextGenerationPreprocessor | |||
| @@ -10,7 +11,8 @@ from ..builder import PIPELINES | |||
| __all__ = ['TextGenerationPipeline'] | |||
| @PIPELINES.register_module(Tasks.text_generation, module_name=r'palm2.0') | |||
| @PIPELINES.register_module( | |||
| Tasks.text_generation, module_name=Pipelines.text_generation) | |||
| class TextGenerationPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -1,5 +1,6 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models import Model | |||
| from modelscope.models.nlp import StructBertForTokenClassification | |||
| from modelscope.preprocessors import TokenClassifcationPreprocessor | |||
| @@ -11,8 +12,7 @@ __all__ = ['WordSegmentationPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.word_segmentation, | |||
| module_name=r'structbert-chinese-word-segmentation') | |||
| Tasks.word_segmentation, module_name=Pipelines.word_segmentation) | |||
| class WordSegmentationPipeline(Pipeline): | |||
| def __init__(self, | |||
| @@ -2,6 +2,7 @@ | |||
| import os.path as osp | |||
| from typing import List, Union | |||
| from modelscope.hub.api import HubApi | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile | |||
| @@ -19,31 +20,63 @@ def is_config_has_model(cfg_file): | |||
| return False | |||
| def is_model_name(model: Union[str, List]): | |||
| """ whether model is a valid modelhub path | |||
| def is_official_hub_path(path: Union[str, List]): | |||
| """ Whether path is a official hub name or a valid local | |||
| path to official hub directory. | |||
| """ | |||
| def is_model_name_impl(model): | |||
| if osp.exists(model): | |||
| cfg_file = osp.join(model, ModelFile.CONFIGURATION) | |||
| def is_official_hub_impl(path): | |||
| if osp.exists(path): | |||
| cfg_file = osp.join(path, ModelFile.CONFIGURATION) | |||
| return osp.exists(cfg_file) | |||
| else: | |||
| try: | |||
| _ = HubApi().get_model(path) | |||
| return True | |||
| except Exception: | |||
| return False | |||
| if isinstance(path, str): | |||
| return is_official_hub_impl(path) | |||
| else: | |||
| results = [is_official_hub_impl(m) for m in path] | |||
| all_true = all(results) | |||
| any_true = any(results) | |||
| if any_true and not all_true: | |||
| raise ValueError( | |||
| f'some model are hub address, some are not, model list: {path}' | |||
| ) | |||
| return all_true | |||
| def is_model(path: Union[str, List]): | |||
| """ whether path is a valid modelhub path and containing model config | |||
| """ | |||
| def is_modelhub_path_impl(path): | |||
| if osp.exists(path): | |||
| cfg_file = osp.join(path, ModelFile.CONFIGURATION) | |||
| if osp.exists(cfg_file): | |||
| return is_config_has_model(cfg_file) | |||
| else: | |||
| return False | |||
| else: | |||
| try: | |||
| cfg_file = model_file_download(model, ModelFile.CONFIGURATION) | |||
| cfg_file = model_file_download(path, ModelFile.CONFIGURATION) | |||
| return is_config_has_model(cfg_file) | |||
| except Exception: | |||
| return False | |||
| if isinstance(model, str): | |||
| return is_model_name_impl(model) | |||
| if isinstance(path, str): | |||
| return is_modelhub_path_impl(path) | |||
| else: | |||
| results = [is_model_name_impl(m) for m in model] | |||
| results = [is_modelhub_path_impl(m) for m in path] | |||
| all_true = all(results) | |||
| any_true = any(results) | |||
| if any_true and not all_true: | |||
| raise ValueError('some model are hub address, some are not') | |||
| raise ValueError( | |||
| f'some models are hub address, some are not, model list: {path}' | |||
| ) | |||
| return all_true | |||
| @@ -5,11 +5,12 @@ from typing import Dict, Union | |||
| from PIL import Image, ImageOps | |||
| from modelscope.fileio import File | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields | |||
| from .builder import PREPROCESSORS | |||
| @PREPROCESSORS.register_module(Fields.cv) | |||
| @PREPROCESSORS.register_module(Fields.cv, Preprocessors.load_image) | |||
| class LoadImage: | |||
| """Load an image from file or url. | |||
| Added or updated keys are "filename", "img", "img_shape", | |||
| @@ -7,6 +7,7 @@ import torch | |||
| from PIL import Image | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields, ModelFile | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| @@ -19,7 +20,7 @@ __all__ = [ | |||
| @PREPROCESSORS.register_module( | |||
| Fields.multi_modal, module_name=r'ofa-image-caption') | |||
| Fields.multi_modal, module_name=Preprocessors.ofa_image_caption) | |||
| class OfaImageCaptionPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -5,6 +5,7 @@ from typing import Any, Dict, Union | |||
| from transformers import AutoTokenizer | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.utils.constant import Fields, InputFields | |||
| from modelscope.utils.type_assert import type_assert | |||
| from .base import Preprocessor | |||
| @@ -31,7 +32,7 @@ class Tokenize(Preprocessor): | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=r'bert-sequence-classification') | |||
| Fields.nlp, module_name=Preprocessors.bert_seq_cls_tokenizer) | |||
| class SequenceClassificationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -124,7 +125,8 @@ class SequenceClassificationPreprocessor(Preprocessor): | |||
| return rst | |||
| @PREPROCESSORS.register_module(Fields.nlp, module_name=r'palm2.0') | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.palm_text_gen_tokenizer) | |||
| class TextGenerationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, tokenizer, *args, **kwargs): | |||
| @@ -180,7 +182,7 @@ class TextGenerationPreprocessor(Preprocessor): | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=r'bert-token-classification') | |||
| Fields.nlp, module_name=Preprocessors.sbert_token_cls_tokenizer) | |||
| class TokenClassifcationPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| @@ -3,6 +3,7 @@ import io | |||
| from typing import Any, Dict, Union | |||
| from modelscope.fileio import File | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.models.audio.tts.frontend import GenericTtsFrontend | |||
| from modelscope.models.base import Model | |||
| from modelscope.utils.audio.tts_exceptions import * # noqa F403 | |||
| @@ -10,11 +11,11 @@ from modelscope.utils.constant import Fields | |||
| from .base import Preprocessor | |||
| from .builder import PREPROCESSORS | |||
| __all__ = ['TextToTacotronSymbols', 'text_to_tacotron_symbols'] | |||
| __all__ = ['TextToTacotronSymbols'] | |||
| @PREPROCESSORS.register_module( | |||
| Fields.audio, module_name=r'text_to_tacotron_symbols') | |||
| Fields.audio, module_name=Preprocessors.text_to_tacotron_symbols) | |||
| class TextToTacotronSymbols(Preprocessor): | |||
| """extract tacotron symbols from text. | |||
| @@ -1,11 +1,49 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| from typing import List, Union | |||
| from modelscope.hub.constants import MODEL_ID_SEPARATOR | |||
| from numpy import deprecate | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.hub.utils.utils import get_cache_dir | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile | |||
| # temp solution before the hub-cache is in place | |||
| @deprecate | |||
| def get_model_cache_dir(model_id: str): | |||
| return os.path.join(get_cache_dir(), model_id) | |||
| def read_config(model_id_or_path: str): | |||
| """ Read config from hub or local path | |||
| Args: | |||
| model_id_or_path (str): Model repo name or local directory path. | |||
| Return: | |||
| config (:obj:`Config`): config object | |||
| """ | |||
| if not os.path.exists(model_id_or_path): | |||
| local_path = model_file_download(model_id_or_path, | |||
| ModelFile.CONFIGURATION) | |||
| else: | |||
| local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION) | |||
| return Config.from_file(local_path) | |||
| def auto_load(model: Union[str, List[str]]): | |||
| if isinstance(model, str): | |||
| if not osp.exists(model): | |||
| model = snapshot_download(model) | |||
| else: | |||
| model = [ | |||
| snapshot_download(m) if not osp.exists(m) else m for m in model | |||
| ] | |||
| return model | |||
| @@ -1,10 +1,10 @@ | |||
| #tts | |||
| h5py==2.10.0 | |||
| #https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl | |||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl | |||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl; python_version=='3.6' | |||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl; python_version=='3.7' | |||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl; python_version=='3.8' | |||
| https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl; python_version=='3.9' | |||
| https://swap.oss-cn-hangzhou.aliyuncs.com/Jiaqi%2Fmaas%2Ftts%2Frequirements%2Fpytorch_wavelets-1.3.0-py3-none-any.whl?Expires=1685688388&OSSAccessKeyId=LTAI4Ffebq4d9jTVDwiSbY4L&Signature=jcQbg5EZ%2Bdys3%2F4BRn3srrKLdIg%3D | |||
| #https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl | |||
| #https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl | |||
| inflect | |||
| keras==2.2.4 | |||
| librosa | |||
| @@ -12,7 +12,7 @@ lxml | |||
| matplotlib | |||
| nara_wpe | |||
| numpy==1.18.* | |||
| protobuf==3.20.* | |||
| protobuf>3,<=3.20 | |||
| ptflops | |||
| PyWavelets>=1.0.0 | |||
| scikit-learn==0.23.2 | |||
| @@ -60,7 +60,7 @@ class ImageMattingTest(unittest.TestCase): | |||
| cv2.imwrite('result.png', result['output_png']) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_modelscope_dataset(self): | |||
| dataset = PyDataset.load('beans', split='train', target='image') | |||
| img_matting = pipeline(Tasks.image_matting, model=self.model_id) | |||
| @@ -3,6 +3,7 @@ import shutil | |||
| import unittest | |||
| from modelscope.fileio import File | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| @@ -42,7 +43,7 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||
| aec = pipeline( | |||
| Tasks.speech_signal_process, | |||
| model=self.model_id, | |||
| pipeline_name=r'speech_dfsmn_aec_psm_16k') | |||
| pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k) | |||
| aec(input, output_path='output.wav') | |||
| @@ -38,31 +38,6 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| break | |||
| print(r) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| model_url = 'https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com' \ | |||
| '/release/easynlp_modelzoo/alibaba-pai/bert-base-sst2.zip' | |||
| cache_path_str = r'.cache/easynlp/bert-base-sst2.zip' | |||
| cache_path = Path(cache_path_str) | |||
| if not cache_path.exists(): | |||
| cache_path.parent.mkdir(parents=True, exist_ok=True) | |||
| cache_path.touch(exist_ok=True) | |||
| with cache_path.open('wb') as ofile: | |||
| ofile.write(File.read(model_url)) | |||
| with zipfile.ZipFile(cache_path_str, 'r') as zipf: | |||
| zipf.extractall(cache_path.parent) | |||
| path = r'.cache/easynlp/' | |||
| model = BertForSequenceClassification(path) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| path, first_sequence='sentence', second_sequence=None) | |||
| pipeline1 = SequenceClassificationPipeline(model, preprocessor) | |||
| self.predict(pipeline1) | |||
| pipeline2 = pipeline( | |||
| Tasks.text_classification, model=model, preprocessor=preprocessor) | |||
| print(pipeline2('Hello world!')) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| @@ -11,6 +11,7 @@ import torch | |||
| from scipy.io.wavfile import write | |||
| from modelscope.fileio import File | |||
| from modelscope.metainfo import Pipelines, Preprocessors | |||
| from modelscope.models import Model, build_model | |||
| from modelscope.models.audio.tts.am import SambertNetHifi16k | |||
| from modelscope.models.audio.tts.vocoder import AttrDict, Hifigan16k | |||
| @@ -32,7 +33,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): | |||
| voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo' | |||
| cfg_preprocessor = dict( | |||
| type='text_to_tacotron_symbols', | |||
| type=Preprocessors.text_to_tacotron_symbols, | |||
| model_name=preprocessor_model_id, | |||
| lang_type=lang_type) | |||
| preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) | |||
| @@ -45,7 +46,7 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): | |||
| self.assertTrue(voc is not None) | |||
| sambert_tts = pipeline( | |||
| pipeline_name='tts-sambert-hifigan-16k', | |||
| pipeline_name=Pipelines.sambert_hifigan_16k_tts, | |||
| config_file='', | |||
| model=[am, voc], | |||
| preprocessor=preprocessor) | |||
| @@ -1,6 +1,7 @@ | |||
| import shutil | |||
| import unittest | |||
| from modelscope.metainfo import Preprocessors | |||
| from modelscope.preprocessors import build_preprocessor | |||
| from modelscope.utils.constant import Fields, InputFields | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -14,7 +15,7 @@ class TtsPreprocessorTest(unittest.TestCase): | |||
| lang_type = 'pinyin' | |||
| text = '今天天气不错,我们去散步吧。' | |||
| cfg = dict( | |||
| type='text_to_tacotron_symbols', | |||
| type=Preprocessors.text_to_tacotron_symbols, | |||
| model_name='damo/speech_binary_tts_frontend_resource', | |||
| lang_type=lang_type) | |||
| preprocessor = build_preprocessor(cfg, Fields.audio) | |||
| @@ -33,6 +33,8 @@ class ImgPreprocessor(Preprocessor): | |||
| class PyDatasetTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 2, | |||
| 'skip test due to dataset api problem') | |||
| def test_ds_basic(self): | |||
| ms_ds_full = PyDataset.load('squad') | |||
| ms_ds_full_hf = hfdata.load_dataset('squad') | |||