diff --git a/maas_lib/models/__init__.py b/maas_lib/models/__init__.py index aa1b3f14..170e525e 100644 --- a/maas_lib/models/__init__.py +++ b/maas_lib/models/__init__.py @@ -2,4 +2,4 @@ from .base import Model from .builder import MODELS, build_model -from .nlp import SequenceClassificationModel +from .nlp import BertForSequenceClassification diff --git a/maas_lib/models/base.py b/maas_lib/models/base.py index 677a136a..10425f6c 100644 --- a/maas_lib/models/base.py +++ b/maas_lib/models/base.py @@ -50,7 +50,7 @@ class Model(ABC): cfg = Config.from_file(osp.join(local_model_dir, CONFIGFILE)) task_name = cfg.task model_cfg = cfg.model - # TODO @wenmeng.zwm may should mannually initialize model after model building + # TODO @wenmeng.zwm may should manually initialize model after model building if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): model_cfg.type = model_cfg.model_type model_cfg.model_dir = local_model_dir diff --git a/maas_lib/models/nlp/sequence_classification_model.py b/maas_lib/models/nlp/sequence_classification_model.py index f77b0fbc..0afdf26e 100644 --- a/maas_lib/models/nlp/sequence_classification_model.py +++ b/maas_lib/models/nlp/sequence_classification_model.py @@ -6,12 +6,12 @@ from maas_lib.utils.constant import Tasks from ..base import Model from ..builder import MODELS -__all__ = ['SequenceClassificationModel'] +__all__ = ['BertForSequenceClassification'] @MODELS.register_module( Tasks.text_classification, module_name=r'bert-sentiment-analysis') -class SequenceClassificationModel(Model): +class BertForSequenceClassification(Model): def __init__(self, model_dir: str, *args, **kwargs): # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs) diff --git a/maas_lib/pipelines/base.py b/maas_lib/pipelines/base.py index 5e387c62..47c6d90b 100644 --- a/maas_lib/pipelines/base.py +++ b/maas_lib/pipelines/base.py @@ -15,6 +15,7 @@ from .util import is_model_name Tensor = Union['torch.Tensor', 'tf.Tensor'] Input = Union[str, PyDataset, 'PIL.Image.Image', 'numpy.ndarray'] +InputModel = Union[str, Model] output_keys = [ ] # 对于不同task的pipeline,规定标准化的输出key,用以对接postprocess,同时也用来标准化postprocess后输出的key @@ -22,10 +23,32 @@ output_keys = [ class Pipeline(ABC): + def initiate_single_model(self, model): + if isinstance(model, str): + if not osp.exists(model): + cache_path = util.get_model_cache_dir(model) + model = cache_path if osp.exists( + cache_path) else snapshot_download(model) + return Model.from_pretrained(model) if is_model_name( + model) else model + elif isinstance(model, Model): + return model + else: + if model: + raise ValueError( + f'model type for single model is either str or Model, but got type {type(model)}' + ) + + def initiate_multiple_models(self, input_models: List[InputModel]): + models = [] + for model in input_models: + models.append(self.initiate_single_model(model)) + return models + def __init__(self, config_file: str = None, - model: Union[Model, str] = None, - preprocessor: Preprocessor = None, + model: Union[InputModel, List[InputModel]] = None, + preprocessor: Union[Preprocessor, List[Preprocessor]] = None, **kwargs): """ Base class for pipeline. @@ -35,31 +58,18 @@ class Pipeline(ABC): Args: config_file(str, optional): Filepath to configuration file. - model: Model name or model object - preprocessor: Preprocessor object + model: (list of) Model name or model object + preprocessor: (list of) Preprocessor object """ if config_file is not None: self.cfg = Config.from_file(config_file) - - if isinstance(model, str): - if not osp.exists(model): - cache_path = util.get_model_cache_dir(model) - if osp.exists(cache_path): - model = cache_path - else: - model = snapshot_download(model) - - if is_model_name(model): - self.model = Model.from_pretrained(model) - else: - self.model = model - elif isinstance(model, Model): - self.model = model + if not isinstance(model, List): + self.model = self.initiate_single_model(model) + self.models = [self.model] else: - if model: - raise ValueError( - f'model type is either str or Model, but got type {type(model)}' - ) + self.models = self.initiate_multiple_models(model) + + self.has_multiple_models = len(self.models) > 1 self.preprocessor = preprocessor def __call__(self, input: Union[Input, List[Input]], *args, @@ -94,15 +104,17 @@ class Pipeline(ABC): def preprocess(self, inputs: Input) -> Dict[str, Any]: """ Provide default implementation based on preprocess_cfg and user can reimplement it - """ assert self.preprocessor is not None, 'preprocess method should be implemented' + assert not isinstance(self.preprocessor, List),\ + 'default implementation does not support using multiple preprocessors.' return self.preprocessor(inputs) def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: """ Provide default implementation using self.model and user can reimplement it """ assert self.model is not None, 'forward method should be implemented' + assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.' return self.model(inputs) @abstractmethod diff --git a/maas_lib/pipelines/builder.py b/maas_lib/pipelines/builder.py index acaccf05..dd146cca 100644 --- a/maas_lib/pipelines/builder.py +++ b/maas_lib/pipelines/builder.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp -from typing import Union +from typing import List, Union import json from maas_hub.file_download import model_file_download @@ -10,7 +10,7 @@ from maas_lib.models.base import Model from maas_lib.utils.config import Config, ConfigDict from maas_lib.utils.constant import CONFIGFILE, Tasks from maas_lib.utils.registry import Registry, build_from_cfg -from .base import Pipeline +from .base import InputModel, Pipeline from .util import is_model_name PIPELINES = Registry('pipelines') @@ -32,7 +32,7 @@ def build_pipeline(cfg: ConfigDict, def pipeline(task: str = None, - model: Union[str, Model] = None, + model: Union[InputModel, List[InputModel]] = None, preprocessor=None, config_file: str = None, pipeline_name: str = None, @@ -75,8 +75,8 @@ def pipeline(task: str = None, cfg.update(kwargs) if model: - assert isinstance(model, (str, Model)), \ - f'model should be either str or Model, but got {type(model)}' + assert isinstance(model, (str, Model, List)), \ + f'model should be either (list of) str or Model, but got {type(model)}' cfg.model = model if preprocessor is not None: diff --git a/maas_lib/pipelines/cv/__init__.py b/maas_lib/pipelines/cv/__init__.py index 79548682..6f877a26 100644 --- a/maas_lib/pipelines/cv/__init__.py +++ b/maas_lib/pipelines/cv/__init__.py @@ -1 +1 @@ -from .image_matting import ImageMatting +from .image_matting_pipeline import ImageMattingPipeline diff --git a/maas_lib/pipelines/cv/image_matting.py b/maas_lib/pipelines/cv/image_matting_pipeline.py similarity index 97% rename from maas_lib/pipelines/cv/image_matting.py rename to maas_lib/pipelines/cv/image_matting_pipeline.py index fdb443f9..0317b4bd 100644 --- a/maas_lib/pipelines/cv/image_matting.py +++ b/maas_lib/pipelines/cv/image_matting_pipeline.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Tuple, Union import cv2 import numpy as np import PIL -from cv2 import COLOR_GRAY2RGB from maas_lib.pipelines.base import Input from maas_lib.preprocessors import load_image @@ -18,7 +17,7 @@ logger = get_logger() @PIPELINES.register_module( Tasks.image_matting, module_name=Tasks.image_matting) -class ImageMatting(Pipeline): +class ImageMattingPipeline(Pipeline): def __init__(self, model: str): super().__init__(model=model) diff --git a/maas_lib/pipelines/nlp/sequence_classification_pipeline.py b/maas_lib/pipelines/nlp/sequence_classification_pipeline.py index 9300035d..014eb4a3 100644 --- a/maas_lib/pipelines/nlp/sequence_classification_pipeline.py +++ b/maas_lib/pipelines/nlp/sequence_classification_pipeline.py @@ -5,7 +5,7 @@ from typing import Any, Dict, Union import json import numpy as np -from maas_lib.models.nlp import SequenceClassificationModel +from maas_lib.models.nlp import BertForSequenceClassification from maas_lib.preprocessors import SequenceClassificationPreprocessor from maas_lib.utils.constant import Tasks from ...models import Model @@ -20,18 +20,20 @@ __all__ = ['SequenceClassificationPipeline'] class SequenceClassificationPipeline(Pipeline): def __init__(self, - model: Union[SequenceClassificationModel, str], + model: Union[BertForSequenceClassification, str], preprocessor: SequenceClassificationPreprocessor = None, **kwargs): """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction Args: - model (SequenceClassificationModel): a model instance + model (BertForSequenceClassification): a model instance preprocessor (SequenceClassificationPreprocessor): a preprocessor instance """ + assert isinstance(model, str) or isinstance(model, BertForSequenceClassification), \ + 'model must be a single str or BertForSequenceClassification' sc_model = model if isinstance( model, - SequenceClassificationModel) else Model.from_pretrained(model) + BertForSequenceClassification) else Model.from_pretrained(model) if preprocessor is None: preprocessor = SequenceClassificationPreprocessor( sc_model.model_dir, diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 4fb475bb..33e8c28c 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -1,5 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import os import os.path as osp import shutil import tempfile diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py index 2db7e67f..f599b205 100644 --- a/tests/pipelines/test_text_classification.py +++ b/tests/pipelines/test_text_classification.py @@ -9,7 +9,7 @@ from pydatasets import PyDataset from maas_lib.fileio import File from maas_lib.models import Model -from maas_lib.models.nlp import SequenceClassificationModel +from maas_lib.models.nlp import BertForSequenceClassification from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util from maas_lib.preprocessors import SequenceClassificationPreprocessor from maas_lib.utils.constant import Tasks @@ -59,7 +59,7 @@ class SequenceClassificationTest(unittest.TestCase): with zipfile.ZipFile(cache_path_str, 'r') as zipf: zipf.extractall(cache_path.parent) path = r'.cache/easynlp/' - model = SequenceClassificationModel(path) + model = BertForSequenceClassification(path) preprocessor = SequenceClassificationPreprocessor( path, first_sequence='sentence', second_sequence=None) pipeline1 = SequenceClassificationPipeline(model, preprocessor)