Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8972440 * [to #42339559] support multiple modelsmaster
| @@ -2,4 +2,4 @@ | |||||
| from .base import Model | from .base import Model | ||||
| from .builder import MODELS, build_model | from .builder import MODELS, build_model | ||||
| from .nlp import SequenceClassificationModel | |||||
| from .nlp import BertForSequenceClassification | |||||
| @@ -50,7 +50,7 @@ class Model(ABC): | |||||
| cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) | cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) | ||||
| task_name = cfg.task | task_name = cfg.task | ||||
| model_cfg = cfg.model | model_cfg = cfg.model | ||||
| # TODO @wenmeng.zwm may should mannually initialize model after model building | |||||
| # TODO @wenmeng.zwm may should manually initialize model after model building | |||||
| if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | ||||
| model_cfg.type = model_cfg.model_type | model_cfg.type = model_cfg.model_type | ||||
| model_cfg.model_dir = local_model_dir | model_cfg.model_dir = local_model_dir | ||||
| @@ -6,12 +6,12 @@ from maas_lib.utils.constant import Tasks | |||||
| from ..base import Model | from ..base import Model | ||||
| from ..builder import MODELS | from ..builder import MODELS | ||||
| __all__ = ['SequenceClassificationModel'] | |||||
| __all__ = ['BertForSequenceClassification'] | |||||
| @MODELS.register_module( | @MODELS.register_module( | ||||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | Tasks.text_classification, module_name=r'bert-sentiment-analysis') | ||||
| class SequenceClassificationModel(Model): | |||||
| class BertForSequenceClassification(Model): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | def __init__(self, model_dir: str, *args, **kwargs): | ||||
| # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) | # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) | ||||
| @@ -15,6 +15,7 @@ from .util import is_model_name | |||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | ||||
| InputModel = Union[str, Model] | |||||
| output_keys = [ | output_keys = [ | ||||
| ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key | ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key | ||||
| @@ -22,10 +23,32 @@ output_keys = [ | |||||
| class Pipeline(ABC): | class Pipeline(ABC): | ||||
| def initiate_single_model(self, model): | |||||
| if isinstance(model, str): | |||||
| if not osp.exists(model): | |||||
| cache_path = util.get_model_cache_dir(model) | |||||
| model = cache_path if osp.exists( | |||||
| cache_path) else snapshot_download(model) | |||||
| return Model.from_pretrained(model) if is_model_name( | |||||
| model) else model | |||||
| elif isinstance(model, Model): | |||||
| return model | |||||
| else: | |||||
| if model: | |||||
| raise ValueError( | |||||
| f'model type for single model is either str or Model, but got type {type(model)}' | |||||
| ) | |||||
| def initiate_multiple_models(self, input_models: List[InputModel]): | |||||
| models = [] | |||||
| for model in input_models: | |||||
| models.append(self.initiate_single_model(model)) | |||||
| return models | |||||
| def __init__(self, | def __init__(self, | ||||
| config_file: str = None, | config_file: str = None, | ||||
| model: Union[Model, str] = None, | |||||
| preprocessor: Preprocessor = None, | |||||
| model: Union[InputModel, List[InputModel]] = None, | |||||
| preprocessor: Union[Preprocessor, List[Preprocessor]] = None, | |||||
| **kwargs): | **kwargs): | ||||
| """ Base class for pipeline. | """ Base class for pipeline. | ||||
| @@ -35,31 +58,18 @@ class Pipeline(ABC): | |||||
| Args: | Args: | ||||
| config_file(str, optional): Filepath to configuration file. | config_file(str, optional): Filepath to configuration file. | ||||
| model: Model name or model object | |||||
| preprocessor: Preprocessor object | |||||
| model: (list of) Model name or model object | |||||
| preprocessor: (list of) Preprocessor object | |||||
| """ | """ | ||||
| if config_file is not None: | if config_file is not None: | ||||
| self.cfg = Config.from_file(config_file) | self.cfg = Config.from_file(config_file) | ||||
| if isinstance(model, str): | |||||
| if not osp.exists(model): | |||||
| cache_path = util.get_model_cache_dir(model) | |||||
| if osp.exists(cache_path): | |||||
| model = cache_path | |||||
| else: | |||||
| model = snapshot_download(model) | |||||
| if is_model_name(model): | |||||
| self.model = Model.from_pretrained(model) | |||||
| else: | |||||
| self.model = model | |||||
| elif isinstance(model, Model): | |||||
| self.model = model | |||||
| if not isinstance(model, List): | |||||
| self.model = self.initiate_single_model(model) | |||||
| self.models = [self.model] | |||||
| else: | else: | ||||
| if model: | |||||
| raise ValueError( | |||||
| f'model type is either str or Model, but got type {type(model)}' | |||||
| ) | |||||
| self.models = self.initiate_multiple_models(model) | |||||
| self.has_multiple_models = len(self.models) > 1 | |||||
| self.preprocessor = preprocessor | self.preprocessor = preprocessor | ||||
| def __call__(self, input: Union[Input, List[Input]], *args, | def __call__(self, input: Union[Input, List[Input]], *args, | ||||
| @@ -94,15 +104,17 @@ class Pipeline(ABC): | |||||
| def preprocess(self, inputs: Input) -> Dict[str, Any]: | def preprocess(self, inputs: Input) -> Dict[str, Any]: | ||||
| """ Provide default implementation based on preprocess_cfg and user can reimplement it | """ Provide default implementation based on preprocess_cfg and user can reimplement it | ||||
| """ | """ | ||||
| assert self.preprocessor is not None, 'preprocess method should be implemented' | assert self.preprocessor is not None, 'preprocess method should be implemented' | ||||
| assert not isinstance(self.preprocessor, List),\ | |||||
| 'default implementation does not support using multiple preprocessors.' | |||||
| return self.preprocessor(inputs) | return self.preprocessor(inputs) | ||||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | ||||
| """ Provide default implementation using self.model and user can reimplement it | """ Provide default implementation using self.model and user can reimplement it | ||||
| """ | """ | ||||
| assert self.model is not None, 'forward method should be implemented' | assert self.model is not None, 'forward method should be implemented' | ||||
| assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.' | |||||
| return self.model(inputs) | return self.model(inputs) | ||||
| @abstractmethod | @abstractmethod | ||||
| @@ -1,7 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os.path as osp | import os.path as osp | ||||
| from typing import Union | |||||
| from typing import List, Union | |||||
| import json | import json | ||||
| from maas_hub.file_download import model_file_download | from maas_hub.file_download import model_file_download | ||||
| @@ -10,7 +10,7 @@ from maas_lib.models.base import Model | |||||
| from maas_lib.utils.config import Config, ConfigDict | from maas_lib.utils.config import Config, ConfigDict | ||||
| from maas_lib.utils.constant import CONFIGFILE, Tasks | from maas_lib.utils.constant import CONFIGFILE, Tasks | ||||
| from maas_lib.utils.registry import Registry, build_from_cfg | from maas_lib.utils.registry import Registry, build_from_cfg | ||||
| from .base import Pipeline | |||||
| from .base import InputModel, Pipeline | |||||
| from .util import is_model_name | from .util import is_model_name | ||||
| PIPELINES = Registry('pipelines') | PIPELINES = Registry('pipelines') | ||||
| @@ -32,7 +32,7 @@ def build_pipeline(cfg: ConfigDict, | |||||
| def pipeline(task: str = None, | def pipeline(task: str = None, | ||||
| model: Union[str, Model] = None, | |||||
| model: Union[InputModel, List[InputModel]] = None, | |||||
| preprocessor=None, | preprocessor=None, | ||||
| config_file: str = None, | config_file: str = None, | ||||
| pipeline_name: str = None, | pipeline_name: str = None, | ||||
| @@ -75,8 +75,8 @@ def pipeline(task: str = None, | |||||
| cfg.update(kwargs) | cfg.update(kwargs) | ||||
| if model: | if model: | ||||
| assert isinstance(model, (str, Model)), \ | |||||
| f'model should be either str or Model, but got {type(model)}' | |||||
| assert isinstance(model, (str, Model, List)), \ | |||||
| f'model should be either (list of) str or Model, but got {type(model)}' | |||||
| cfg.model = model | cfg.model = model | ||||
| if preprocessor is not None: | if preprocessor is not None: | ||||
| @@ -1 +1 @@ | |||||
| from .image_matting import ImageMatting | |||||
| from .image_matting_pipeline import ImageMattingPipeline | |||||
| @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Tuple, Union | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| import PIL | import PIL | ||||
| from cv2 import COLOR_GRAY2RGB | |||||
| from maas_lib.pipelines.base import Input | from maas_lib.pipelines.base import Input | ||||
| from maas_lib.preprocessors import load_image | from maas_lib.preprocessors import load_image | ||||
| @@ -18,7 +17,7 @@ logger = get_logger() | |||||
| @PIPELINES.register_module( | @PIPELINES.register_module( | ||||
| Tasks.image_matting, module_name=Tasks.image_matting) | Tasks.image_matting, module_name=Tasks.image_matting) | ||||
| class ImageMatting(Pipeline): | |||||
| class ImageMattingPipeline(Pipeline): | |||||
| def __init__(self, model: str): | def __init__(self, model: str): | ||||
| super().__init__(model=model) | super().__init__(model=model) | ||||
| @@ -5,7 +5,7 @@ from typing import Any, Dict, Union | |||||
| import json | import json | ||||
| import numpy as np | import numpy as np | ||||
| from maas_lib.models.nlp import SequenceClassificationModel | |||||
| from maas_lib.models.nlp import BertForSequenceClassification | |||||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | from maas_lib.preprocessors import SequenceClassificationPreprocessor | ||||
| from maas_lib.utils.constant import Tasks | from maas_lib.utils.constant import Tasks | ||||
| from ...models import Model | from ...models import Model | ||||
| @@ -20,18 +20,20 @@ __all__ = ['SequenceClassificationPipeline'] | |||||
| class SequenceClassificationPipeline(Pipeline): | class SequenceClassificationPipeline(Pipeline): | ||||
| def __init__(self, | def __init__(self, | ||||
| model: Union[SequenceClassificationModel, str], | |||||
| model: Union[BertForSequenceClassification, str], | |||||
| preprocessor: SequenceClassificationPreprocessor = None, | preprocessor: SequenceClassificationPreprocessor = None, | ||||
| **kwargs): | **kwargs): | ||||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | ||||
| Args: | Args: | ||||
| model (SequenceClassificationModel): a model instance | |||||
| model (BertForSequenceClassification): a model instance | |||||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | ||||
| """ | """ | ||||
| assert isinstance(model, str) or isinstance(model, BertForSequenceClassification), \ | |||||
| 'model must be a single str or BertForSequenceClassification' | |||||
| sc_model = model if isinstance( | sc_model = model if isinstance( | ||||
| model, | model, | ||||
| SequenceClassificationModel) else Model.from_pretrained(model) | |||||
| BertForSequenceClassification) else Model.from_pretrained(model) | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor = SequenceClassificationPreprocessor( | preprocessor = SequenceClassificationPreprocessor( | ||||
| sc_model.model_dir, | sc_model.model_dir, | ||||
| @@ -1,5 +1,4 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| import os.path as osp | import os.path as osp | ||||
| import shutil | import shutil | ||||
| import tempfile | import tempfile | ||||
| @@ -9,7 +9,7 @@ from pydatasets import PyDataset | |||||
| from maas_lib.fileio import File | from maas_lib.fileio import File | ||||
| from maas_lib.models import Model | from maas_lib.models import Model | ||||
| from maas_lib.models.nlp import SequenceClassificationModel | |||||
| from maas_lib.models.nlp import BertForSequenceClassification | |||||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util | from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util | ||||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | from maas_lib.preprocessors import SequenceClassificationPreprocessor | ||||
| from maas_lib.utils.constant import Tasks | from maas_lib.utils.constant import Tasks | ||||
| @@ -59,7 +59,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| with zipfile.ZipFile(cache_path_str, 'r') as zipf: | with zipfile.ZipFile(cache_path_str, 'r') as zipf: | ||||
| zipf.extractall(cache_path.parent) | zipf.extractall(cache_path.parent) | ||||
| path = r'.cache/easynlp/' | path = r'.cache/easynlp/' | ||||
| model = SequenceClassificationModel(path) | |||||
| model = BertForSequenceClassification(path) | |||||
| preprocessor = SequenceClassificationPreprocessor( | preprocessor = SequenceClassificationPreprocessor( | ||||
| path, first_sequence='sentence', second_sequence=None) | path, first_sequence='sentence', second_sequence=None) | ||||
| pipeline1 = SequenceClassificationPipeline(model, preprocessor) | pipeline1 = SequenceClassificationPipeline(model, preprocessor) | ||||