Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8972440 * [to #42339559] support multiple models master
| @@ -2,4 +2,4 @@ | |||
| from .base import Model | |||
| from .builder import MODELS, build_model | |||
| from .nlp import SequenceClassificationModel | |||
| from .nlp import BertForSequenceClassification | |||
| @@ -50,7 +50,7 @@ class Model(ABC): | |||
| cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) | |||
| task_name = cfg.task | |||
| model_cfg = cfg.model | |||
| # TODO @wenmeng.zwm may should mannually initialize model after model building | |||
| # TODO @wenmeng.zwm may should manually initialize model after model building | |||
| if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): | |||
| model_cfg.type = model_cfg.model_type | |||
| model_cfg.model_dir = local_model_dir | |||
| @@ -6,12 +6,12 @@ from maas_lib.utils.constant import Tasks | |||
| from ..base import Model | |||
| from ..builder import MODELS | |||
| __all__ = ['SequenceClassificationModel'] | |||
| __all__ = ['BertForSequenceClassification'] | |||
| @MODELS.register_module( | |||
| Tasks.text_classification, module_name=r'bert-sentiment-analysis') | |||
| class SequenceClassificationModel(Model): | |||
| class BertForSequenceClassification(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) | |||
| @@ -15,6 +15,7 @@ from .util import is_model_name | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] | |||
| InputModel = Union[str, Model] | |||
| output_keys = [ | |||
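| ] # For pipelines of different tasks, define the standardized output keys; they are used to interface with postprocess and also to standardize the keys output after postprocess | |||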
| @@ -22,10 +23,32 @@ output_keys = [ | |||
| class Pipeline(ABC): | |||
| def initiate_single_model(self, model): | |||
| if isinstance(model, str): | |||
| if not osp.exists(model): | |||
| cache_path = util.get_model_cache_dir(model) | |||
| model = cache_path if osp.exists( | |||
| cache_path) else snapshot_download(model) | |||
| return Model.from_pretrained(model) if is_model_name( | |||
| model) else model | |||
| elif isinstance(model, Model): | |||
| return model | |||
| else: | |||
| if model: | |||
| raise ValueError( | |||
| f'model type for single model is either str or Model, but got type {type(model)}' | |||
| ) | |||
| def initiate_multiple_models(self, input_models: List[InputModel]): | |||
| models = [] | |||
| for model in input_models: | |||
| models.append(self.initiate_single_model(model)) | |||
| return models | |||
| def __init__(self, | |||
| config_file: str = None, | |||
| model: Union[Model, str] = None, | |||
| preprocessor: Preprocessor = None, | |||
| model: Union[InputModel, List[InputModel]] = None, | |||
| preprocessor: Union[Preprocessor, List[Preprocessor]] = None, | |||
| **kwargs): | |||
| """ Base class for pipeline. | |||
| @@ -35,31 +58,18 @@ class Pipeline(ABC): | |||
| Args: | |||
| config_file(str, optional): Filepath to configuration file. | |||
| model: Model name or model object | |||
| preprocessor: Preprocessor object | |||
| model: (list of) Model name or model object | |||
| preprocessor: (list of) Preprocessor object | |||
| """ | |||
| if config_file is not None: | |||
| self.cfg = Config.from_file(config_file) | |||
| if isinstance(model, str): | |||
| if not osp.exists(model): | |||
| cache_path = util.get_model_cache_dir(model) | |||
| if osp.exists(cache_path): | |||
| model = cache_path | |||
| else: | |||
| model = snapshot_download(model) | |||
| if is_model_name(model): | |||
| self.model = Model.from_pretrained(model) | |||
| else: | |||
| self.model = model | |||
| elif isinstance(model, Model): | |||
| self.model = model | |||
| if not isinstance(model, List): | |||
| self.model = self.initiate_single_model(model) | |||
| self.models = [self.model] | |||
| else: | |||
| if model: | |||
| raise ValueError( | |||
| f'model type is either str or Model, but got type {type(model)}' | |||
| ) | |||
| self.models = self.initiate_multiple_models(model) | |||
| self.has_multiple_models = len(self.models) > 1 | |||
| self.preprocessor = preprocessor | |||
| def __call__(self, input: Union[Input, List[Input]], *args, | |||
| @@ -94,15 +104,17 @@ class Pipeline(ABC): | |||
| def preprocess(self, inputs: Input) -> Dict[str, Any]: | |||
| """ Provide default implementation based on preprocess_cfg and user can reimplement it | |||
| """ | |||
| assert self.preprocessor is not None, 'preprocess method should be implemented' | |||
| assert not isinstance(self.preprocessor, List),\ | |||
| 'default implementation does not support using multiple preprocessors.' | |||
| return self.preprocessor(inputs) | |||
| def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| """ Provide default implementation using self.model and user can reimplement it | |||
| """ | |||
| assert self.model is not None, 'forward method should be implemented' | |||
| assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.' | |||
| return self.model(inputs) | |||
| @abstractmethod | |||
| @@ -1,7 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from typing import Union | |||
| from typing import List, Union | |||
| import json | |||
| from maas_hub.file_download import model_file_download | |||
| @@ -10,7 +10,7 @@ from maas_lib.models.base import Model | |||
| from maas_lib.utils.config import Config, ConfigDict | |||
| from maas_lib.utils.constant import CONFIGFILE, Tasks | |||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||
| from .base import Pipeline | |||
| from .base import InputModel, Pipeline | |||
| from .util import is_model_name | |||
| PIPELINES = Registry('pipelines') | |||
| @@ -32,7 +32,7 @@ def build_pipeline(cfg: ConfigDict, | |||
| def pipeline(task: str = None, | |||
| model: Union[str, Model] = None, | |||
| model: Union[InputModel, List[InputModel]] = None, | |||
| preprocessor=None, | |||
| config_file: str = None, | |||
| pipeline_name: str = None, | |||
| @@ -75,8 +75,8 @@ def pipeline(task: str = None, | |||
| cfg.update(kwargs) | |||
| if model: | |||
| assert isinstance(model, (str, Model)), \ | |||
| f'model should be either str or Model, but got {type(model)}' | |||
| assert isinstance(model, (str, Model, List)), \ | |||
| f'model should be either (list of) str or Model, but got {type(model)}' | |||
| cfg.model = model | |||
| if preprocessor is not None: | |||
| @@ -1 +1 @@ | |||
| from .image_matting import ImageMatting | |||
| from .image_matting_pipeline import ImageMattingPipeline | |||
| @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Tuple, Union | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| from cv2 import COLOR_GRAY2RGB | |||
| from maas_lib.pipelines.base import Input | |||
| from maas_lib.preprocessors import load_image | |||
| @@ -18,7 +17,7 @@ logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_matting, module_name=Tasks.image_matting) | |||
| class ImageMatting(Pipeline): | |||
| class ImageMattingPipeline(Pipeline): | |||
| def __init__(self, model: str): | |||
| super().__init__(model=model) | |||
| @@ -5,7 +5,7 @@ from typing import Any, Dict, Union | |||
| import json | |||
| import numpy as np | |||
| from maas_lib.models.nlp import SequenceClassificationModel | |||
| from maas_lib.models.nlp import BertForSequenceClassification | |||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
| from maas_lib.utils.constant import Tasks | |||
| from ...models import Model | |||
| @@ -20,18 +20,20 @@ __all__ = ['SequenceClassificationPipeline'] | |||
| class SequenceClassificationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[SequenceClassificationModel, str], | |||
| model: Union[BertForSequenceClassification, str], | |||
| preprocessor: SequenceClassificationPreprocessor = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction | |||
| Args: | |||
| model (SequenceClassificationModel): a model instance | |||
| model (BertForSequenceClassification): a model instance | |||
| preprocessor (SequenceClassificationPreprocessor): a preprocessor instance | |||
| """ | |||
| assert isinstance(model, str) or isinstance(model, BertForSequenceClassification), \ | |||
| 'model must be a single str or BertForSequenceClassification' | |||
| sc_model = model if isinstance( | |||
| model, | |||
| SequenceClassificationModel) else Model.from_pretrained(model) | |||
| BertForSequenceClassification) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| sc_model.model_dir, | |||
| @@ -1,5 +1,4 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import shutil | |||
| import tempfile | |||
| @@ -9,7 +9,7 @@ from pydatasets import PyDataset | |||
| from maas_lib.fileio import File | |||
| from maas_lib.models import Model | |||
| from maas_lib.models.nlp import SequenceClassificationModel | |||
| from maas_lib.models.nlp import BertForSequenceClassification | |||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util | |||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
| from maas_lib.utils.constant import Tasks | |||
| @@ -59,7 +59,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| with zipfile.ZipFile(cache_path_str, 'r') as zipf: | |||
| zipf.extractall(cache_path.parent) | |||
| path = r'.cache/easynlp/' | |||
| model = SequenceClassificationModel(path) | |||
| model = BertForSequenceClassification(path) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| path, first_sequence='sentence', second_sequence=None) | |||
| pipeline1 = SequenceClassificationPipeline(model, preprocessor) | |||