Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8997599master
| @@ -11,6 +11,7 @@ from modelscope.preprocessors import Preprocessor | |||
| from modelscope.pydatasets import PyDataset | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.hub import get_model_cache_dir | |||
| from modelscope.utils.logger import get_logger | |||
| from .util import is_model_name | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| @@ -20,11 +21,15 @@ InputModel = Union[str, Model] | |||
| output_keys = [ | |||
| ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key | |||
| logger = get_logger() | |||
| class Pipeline(ABC): | |||
| def initiate_single_model(self, model): | |||
| if isinstance(model, str): | |||
| logger.info(f'initiate model from {model}') | |||
| # TODO @wenmeng.zwm replace model.startswith('damo/') with get_model | |||
| if isinstance(model, str) and model.startswith('damo/'): | |||
| if not osp.exists(model): | |||
| cache_path = get_model_cache_dir(model) | |||
| model = cache_path if osp.exists( | |||
| @@ -34,10 +39,11 @@ class Pipeline(ABC): | |||
| elif isinstance(model, Model): | |||
| return model | |||
| else: | |||
| if model: | |||
| if model and not isinstance(model, str): | |||
| raise ValueError( | |||
| f'model type for single model is either str or Model, but got type {type(model)}' | |||
| ) | |||
| return model | |||
| def initiate_multiple_models(self, input_models: List[InputModel]): | |||
| models = [] | |||
| @@ -1,7 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from typing import Union | |||
| from typing import List, Union | |||
| import json | |||
| from maas_hub.file_download import model_file_download | |||
| @@ -44,7 +44,7 @@ def build_pipeline(cfg: ConfigDict, | |||
| def pipeline(task: str = None, | |||
| model: Union[str, Model] = None, | |||
| model: Union[str, List[str], Model, List[Model]] = None, | |||
| preprocessor=None, | |||
| config_file: str = None, | |||
| pipeline_name: str = None, | |||
| @@ -56,7 +56,7 @@ def pipeline(task: str = None, | |||
| Args: | |||
| task (str): Task name defining which pipeline will be returned. | |||
| model (str or obj:`Model`): model name or model object. | |||
| model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object. | |||
| preprocessor: preprocessor object. | |||
| config_file (str, optional): path to config file. | |||
| pipeline_name (str, optional): pipeline class name or alias name. | |||
| @@ -68,23 +68,42 @@ def pipeline(task: str = None, | |||
| Examples: | |||
| ```python | |||
| >>> # Using default model for a task | |||
| >>> p = pipeline('image-classification') | |||
| >>> p = pipeline('text-classification', model='distilbert-base-uncased') | |||
| >>> # Using model object | |||
| >>> # Using pipeline with a model name | |||
| >>> p = pipeline('text-classification', model='damo/distilbert-base-uncased') | |||
| >>> # Using pipeline with a model object | |||
| >>> resnet = Model.from_pretrained('Resnet') | |||
| >>> p = pipeline('image-classification', model=resnet) | |||
| >>> # Using pipeline with a list of model names | |||
| >>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2']) | |||
| """ | |||
| if task is None and pipeline_name is None: | |||
| raise ValueError('task or pipeline_name is required') | |||
| if pipeline_name is None: | |||
| # get default pipeline for this task | |||
| pipeline_name, default_model_repo = get_default_pipeline_info(task) | |||
| if isinstance(model, str) \ | |||
| or (isinstance(model, list) and isinstance(model[0], str)): | |||
| # if is_model_name(model): | |||
| if (isinstance(model, str) and model.startswith('damo/')) \ | |||
| or (isinstance(model, list) and model[0].startswith('damo/')) \ | |||
| or (isinstance(model, str) and osp.exists(model)): | |||
| # TODO @wenmeng.zwm add support when model is a str of modelhub address | |||
| # read pipeline info from modelhub configuration file. | |||
| pipeline_name, default_model_repo = get_default_pipeline_info( | |||
| task) | |||
| else: | |||
| pipeline_name = get_pipeline_by_model_name(task, model) | |||
| else: | |||
| pipeline_name, default_model_repo = get_default_pipeline_info(task) | |||
| if model is None: | |||
| model = default_model_repo | |||
| assert isinstance(model, (type(None), str, Model)), \ | |||
| f'model should be either None, str or Model, but got {type(model)}' | |||
| assert isinstance(model, (type(None), str, Model, list)), \ | |||
| f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}' | |||
| cfg = ConfigDict(type=pipeline_name, model=model) | |||
| @@ -134,3 +153,19 @@ def get_default_pipeline_info(task): | |||
| else: | |||
| pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] | |||
| return pipeline_name, default_model | |||
| def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]): | |||
| """ Get pipeline name by task name and model name | |||
| Args: | |||
| task (str): task name. | |||
| model (str| list[str]): model names | |||
| """ | |||
| if isinstance(model, str): | |||
| model_key = model | |||
| else: | |||
| model_key = '_'.join(model) | |||
| assert model_key in PIPELINES.modules[task], \ | |||
| f'pipeline for task {task} model {model_key} not found.' | |||
| return model_key | |||
| @@ -1,6 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| from typing import List, Union | |||
| import json | |||
| from maas_hub.file_download import model_file_download | |||
| @@ -8,23 +9,38 @@ from maas_hub.file_download import model_file_download | |||
| from modelscope.utils.constant import CONFIGFILE | |||
| def is_model_name(model): | |||
| if osp.exists(model): | |||
| if osp.exists(osp.join(model, CONFIGFILE)): | |||
| return True | |||
| def is_model_name(model: Union[str, List]): | |||
| """ whether model is a valid modelhub path | |||
| """ | |||
| def is_model_name_impl(model): | |||
| if osp.exists(model): | |||
| if osp.exists(osp.join(model, CONFIGFILE)): | |||
| return True | |||
| else: | |||
| return False | |||
| else: | |||
| return False | |||
| # try: | |||
| # cfg_file = model_file_download(model, CONFIGFILE) | |||
| # except Exception: | |||
| # cfg_file = None | |||
| # TODO @wenmeng.zwm use exception instead of | |||
| # following tricky logic | |||
| cfg_file = model_file_download(model, CONFIGFILE) | |||
| with open(cfg_file, 'r') as infile: | |||
| cfg = json.load(infile) | |||
| if 'Code' in cfg: | |||
| return False | |||
| else: | |||
| return True | |||
| if isinstance(model, str): | |||
| return is_model_name_impl(model) | |||
| else: | |||
| # try: | |||
| # cfg_file = model_file_download(model, CONFIGFILE) | |||
| # except Exception: | |||
| # cfg_file = None | |||
| # TODO @wenmeng.zwm use exception instead of | |||
| # following tricky logic | |||
| cfg_file = model_file_download(model, CONFIGFILE) | |||
| with open(cfg_file, 'r') as infile: | |||
| cfg = json.load(infile) | |||
| if 'Code' in cfg: | |||
| return False | |||
| else: | |||
| return True | |||
| results = [is_model_name_impl(m) for m in model] | |||
| all_true = all(results) | |||
| any_true = any(results) | |||
| if any_true and not all_true: | |||
| raise ValueError('some model are hub address, some are not') | |||
| return all_true | |||
| @@ -0,0 +1,68 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from asyncio import Task | |||
| from typing import Any, Dict, List, Tuple, Union | |||
| import numpy as np | |||
| import PIL | |||
| from modelscope.models.base import Model | |||
| from modelscope.pipelines import Pipeline, pipeline | |||
| from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.registry import default_group | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| group_key=Tasks.image_tagging, module_name='custom_single_model') | |||
| class CustomSingleModelPipeline(Pipeline): | |||
| def __init__(self, | |||
| config_file: str = None, | |||
| model: List[Union[str, Model]] = None, | |||
| preprocessor=None, | |||
| **kwargs): | |||
| super().__init__(config_file, model, preprocessor, **kwargs) | |||
| assert isinstance(model, str), 'model is not str' | |||
| print(model) | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return super().postprocess(inputs) | |||
| @PIPELINES.register_module( | |||
| group_key=Tasks.image_tagging, module_name='model1_model2') | |||
| class CustomMultiModelPipeline(Pipeline): | |||
| def __init__(self, | |||
| config_file: str = None, | |||
| model: List[Union[str, Model]] = None, | |||
| preprocessor=None, | |||
| **kwargs): | |||
| super().__init__(config_file, model, preprocessor, **kwargs) | |||
| assert isinstance(model, list), 'model is not list' | |||
| for m in model: | |||
| assert isinstance(m, str), 'submodel is not str' | |||
| print(m) | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return super().postprocess(inputs) | |||
| class PipelineInterfaceTest(unittest.TestCase): | |||
| def test_single_model(self): | |||
| pipe = pipeline(Tasks.image_tagging, model='custom_single_model') | |||
| assert isinstance(pipe, CustomSingleModelPipeline) | |||
| def test_multi_model(self): | |||
| pipe = pipeline(Tasks.image_tagging, model=['model1', 'model2']) | |||
| assert isinstance(pipe, CustomMultiModelPipeline) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||