Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8997599master
@@ -11,6 +11,7 @@ from modelscope.preprocessors import Preprocessor
 from modelscope.pydatasets import PyDataset
 from modelscope.utils.config import Config
 from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.logger import get_logger
 from .util import is_model_name

 Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -20,11 +21,15 @@ InputModel = Union[str, Model]
 output_keys = [
 ]  # For each task's pipeline, define standardized output keys used to interface with postprocess and to standardize the keys of the postprocess output
+logger = get_logger()

 class Pipeline(ABC):

     def initiate_single_model(self, model):
         if isinstance(model, str):
+            logger.info(f'initiate model from {model}')
+        # TODO @wenmeng.zwm replace model.startswith('damo/') with get_model
+        if isinstance(model, str) and model.startswith('damo/'):
             if not osp.exists(model):
                 cache_path = get_model_cache_dir(model)
                 model = cache_path if osp.exists(
@@ -34,10 +39,11 @@ class Pipeline(ABC):
         elif isinstance(model, Model):
             return model
         else:
-            if model:
+            if model and not isinstance(model, str):
                 raise ValueError(
                     f'model type for single model is either str or Model, but got type {type(model)}'
                 )
+            return model

     def initiate_multiple_models(self, input_models: List[InputModel]):
         models = []
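
The hunk stops right after `models = []`, so the rest of `initiate_multiple_models` is not visible in this review. A minimal sketch of how the list support presumably composes with the updated `initiate_single_model` (illustration only, not the PR code):

```python
# Sketch only: assumes initiate_multiple_models simply delegates each entry
# to initiate_single_model, so hub ids ('damo/...'), local paths and Model
# objects can be mixed in one list.
def initiate_multiple_models(self, input_models: List[InputModel]):
    models = []
    for model in input_models:
        models.append(self.initiate_single_model(model))
    return models
```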
@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp
-from typing import Union
+from typing import List, Union

 import json

 from maas_hub.file_download import model_file_download
@@ -44,7 +44,7 @@ def build_pipeline(cfg: ConfigDict,

 def pipeline(task: str = None,
-             model: Union[str, Model] = None,
+             model: Union[str, List[str], Model, List[Model]] = None,
              preprocessor=None,
              config_file: str = None,
              pipeline_name: str = None,
@@ -56,7 +56,7 @@ def pipeline(task: str = None,
     Args:
         task (str): Task name defining which pipeline will be returned.
-        model (str or obj:`Model`): model name or model object.
+        model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object.
         preprocessor: preprocessor object.
         config_file (str, optional): path to config file.
         pipeline_name (str, optional): pipeline class name or alias name.
@@ -68,23 +68,42 @@ def pipeline(task: str = None,
     Examples:
     ```python
+    >>> # Using default model for a task
     >>> p = pipeline('image-classification')
-    >>> p = pipeline('text-classification', model='distilbert-base-uncased')
-    >>> # Using model object
+    >>> # Using pipeline with a model name
+    >>> p = pipeline('text-classification', model='damo/distilbert-base-uncased')
+    >>> # Using pipeline with a model object
     >>> resnet = Model.from_pretrained('Resnet')
     >>> p = pipeline('image-classification', model=resnet)
+    >>> # Using pipeline with a list of model names
+    >>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2'])
     """
     if task is None and pipeline_name is None:
         raise ValueError('task or pipeline_name is required')

     if pipeline_name is None:
         # get default pipeline for this task
-        pipeline_name, default_model_repo = get_default_pipeline_info(task)
+        if isinstance(model, str) \
+                or (isinstance(model, list) and isinstance(model[0], str)):
+            # if is_model_name(model):
+            if (isinstance(model, str) and model.startswith('damo/')) \
+                    or (isinstance(model, list) and model[0].startswith('damo/')) \
+                    or (isinstance(model, str) and osp.exists(model)):
+                # TODO @wenmeng.zwm add support when model is a str of modelhub address
+                # read pipeline info from modelhub configuration file.
+                pipeline_name, default_model_repo = get_default_pipeline_info(
+                    task)
+            else:
+                pipeline_name = get_pipeline_by_model_name(task, model)
+        else:
+            pipeline_name, default_model_repo = get_default_pipeline_info(task)

     if model is None:
         model = default_model_repo

-    assert isinstance(model, (type(None), str, Model)), \
-        f'model should be either None, str or Model, but got {type(model)}'
+    assert isinstance(model, (type(None), str, Model, list)), \
+        f'model should be either None, str, List[str], Model, or List[Model], but got {type(model)}'

     cfg = ConfigDict(type=pipeline_name, model=model)
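
To make the new dispatch above easier to follow, here is how different `model` arguments would be routed by this hunk (all model names below are placeholders, not real hub repositories):

```python
# Placeholders only; none of these names refer to real hub models.
pipeline('image-classification')                          # model=None -> default pipeline and default model repo
pipeline('image-classification', model='damo/resnet')     # 'damo/' prefix -> task default pipeline (hub config read is still a TODO)
pipeline('image-classification', model='/tmp/my_model')   # existing local path -> task default pipeline
pipeline(Tasks.image_tagging, model='custom_single_model')    # plain name -> get_pipeline_by_model_name
pipeline(Tasks.image_tagging, model=['model1', 'model2'])     # list of plain names -> joined key 'model1_model2'
```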
@@ -134,3 +153,19 @@ def get_default_pipeline_info(task):
     else:
         pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
     return pipeline_name, default_model
+
+
+def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]):
+    """ Get pipeline name by task name and model name
+
+    Args:
+        task (str): task name.
+        model (str| list[str]): model names
+    """
+    if isinstance(model, str):
+        model_key = model
+    else:
+        model_key = '_'.join(model)
+
+    assert model_key in PIPELINES.modules[task], \
+        f'pipeline for task {task} model {model_key} not found.'
+    return model_key
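
The registry key for a model list is simply the names joined with underscores, which is why the multi-model test pipeline further down registers itself under `model1_model2`:

```python
# Illustrative: how get_pipeline_by_model_name builds and checks the key.
model_key = '_'.join(['model1', 'model2'])                    # -> 'model1_model2'
assert model_key in PIPELINES.modules[Tasks.image_tagging]    # must be registered under exactly this name
```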
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import os.path as osp
+from typing import List, Union

 import json

 from maas_hub.file_download import model_file_download
@@ -8,23 +9,38 @@ from maas_hub.file_download import model_file_download
 from modelscope.utils.constant import CONFIGFILE


-def is_model_name(model):
-    if osp.exists(model):
-        if osp.exists(osp.join(model, CONFIGFILE)):
-            return True
-        else:
-            return False
-    else:
-        # try:
-        #     cfg_file = model_file_download(model, CONFIGFILE)
-        # except Exception:
-        #     cfg_file = None
-        # TODO @wenmeng.zwm use exception instead of
-        # following tricky logic
-        cfg_file = model_file_download(model, CONFIGFILE)
-        with open(cfg_file, 'r') as infile:
-            cfg = json.load(infile)
-        if 'Code' in cfg:
-            return False
-        else:
-            return True
+def is_model_name(model: Union[str, List]):
+    """ whether model is a valid modelhub path
+    """
+
+    def is_model_name_impl(model):
+        if osp.exists(model):
+            if osp.exists(osp.join(model, CONFIGFILE)):
+                return True
+            else:
+                return False
+        else:
+            # try:
+            #     cfg_file = model_file_download(model, CONFIGFILE)
+            # except Exception:
+            #     cfg_file = None
+            # TODO @wenmeng.zwm use exception instead of
+            # following tricky logic
+            cfg_file = model_file_download(model, CONFIGFILE)
+            with open(cfg_file, 'r') as infile:
+                cfg = json.load(infile)
+            if 'Code' in cfg:
+                return False
+            else:
+                return True
+
+    if isinstance(model, str):
+        return is_model_name_impl(model)
+    else:
+        results = [is_model_name_impl(m) for m in model]
+        all_true = all(results)
+        any_true = any(results)
+        if any_true and not all_true:
+            raise ValueError('some model are hub address, some are not')
+        return all_true
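
For reference, the list behavior of the reworked `is_model_name` can be summarized as follows (all identifiers below are hypothetical):

```python
# Hypothetical identifiers, for illustration only.
is_model_name('/some/local/dir')                   # True if the dir contains CONFIGFILE, else False
is_model_name('damo/some-model')                   # falls through to model_file_download and the 'Code' check
is_model_name(['damo/model-a', 'damo/model-b'])    # True only if every entry validates
is_model_name(['damo/model-a', '/dir/without/config'])  # mixed True/False results raise ValueError
```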
@@ -0,0 +1,68 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+from asyncio import Task
+from typing import Any, Dict, List, Tuple, Union
+
+import numpy as np
+import PIL
+
+from modelscope.models.base import Model
+from modelscope.pipelines import Pipeline, pipeline
+from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from modelscope.utils.registry import default_group
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    group_key=Tasks.image_tagging, module_name='custom_single_model')
+class CustomSingleModelPipeline(Pipeline):
+
+    def __init__(self,
+                 config_file: str = None,
+                 model: List[Union[str, Model]] = None,
+                 preprocessor=None,
+                 **kwargs):
+        super().__init__(config_file, model, preprocessor, **kwargs)
+        assert isinstance(model, str), 'model is not str'
+        print(model)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return super().postprocess(inputs)
+
+
+@PIPELINES.register_module(
+    group_key=Tasks.image_tagging, module_name='model1_model2')
+class CustomMultiModelPipeline(Pipeline):
+
+    def __init__(self,
+                 config_file: str = None,
+                 model: List[Union[str, Model]] = None,
+                 preprocessor=None,
+                 **kwargs):
+        super().__init__(config_file, model, preprocessor, **kwargs)
+        assert isinstance(model, list), 'model is not list'
+        for m in model:
+            assert isinstance(m, str), 'submodel is not str'
+            print(m)
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return super().postprocess(inputs)
+
+
+class PipelineInterfaceTest(unittest.TestCase):
+
+    def test_single_model(self):
+        pipe = pipeline(Tasks.image_tagging, model='custom_single_model')
+        assert isinstance(pipe, CustomSingleModelPipeline)
+
+    def test_multi_model(self):
+        pipe = pipeline(Tasks.image_tagging, model=['model1', 'model2'])
+        assert isinstance(pipe, CustomMultiModelPipeline)
+
+
+if __name__ == '__main__':
+    unittest.main()