diff --git a/docs/source/develop.md b/docs/source/develop.md
index 0d4f7f26..c048bef7 100644
--- a/docs/source/develop.md
+++ b/docs/source/develop.md
@@ -71,12 +71,18 @@ TODO
 * Feature
   ```shell
   [to #AONE_ID] feat: commit title
+
+  Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062
+
   * commit msg1
   * commit msg2
   ```
 * Bugfix
   ```shell
   [to #AONE_ID] fix: commit title
+
+  Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062
+
   * commit msg1
   * commit msg2
   ```
diff --git a/maas_lib/models/base.py b/maas_lib/models/base.py
index 10425f6c..efda1b3e 100644
--- a/maas_lib/models/base.py
+++ b/maas_lib/models/base.py
@@ -8,9 +8,9 @@ from maas_hub.file_download import model_file_download
 from maas_hub.snapshot_download import snapshot_download
 
 from maas_lib.models.builder import build_model
-from maas_lib.pipelines import util
 from maas_lib.utils.config import Config
 from maas_lib.utils.constant import CONFIGFILE
+from maas_lib.utils.hub import get_model_cache_dir
 
 Tensor = Union['torch.Tensor', 'tf.Tensor']
 
@@ -40,7 +40,7 @@ class Model(ABC):
         if osp.exists(model_name_or_path):
             local_model_dir = model_name_or_path
         else:
-            cache_path = util.get_model_cache_dir(model_name_or_path)
+            cache_path = get_model_cache_dir(model_name_or_path)
             local_model_dir = cache_path if osp.exists(
                 cache_path) else snapshot_download(model_name_or_path)
             # else:
diff --git a/maas_lib/pipelines/base.py b/maas_lib/pipelines/base.py
index 76747b05..1ba8c36a 100644
--- a/maas_lib/pipelines/base.py
+++ b/maas_lib/pipelines/base.py
@@ -6,11 +6,11 @@ from typing import Any, Dict, Generator, List, Union
 
 from maas_hub.snapshot_download import snapshot_download
 
-from maas_lib.models import Model
-from maas_lib.pipelines import util
+from maas_lib.models.base import Model
 from maas_lib.preprocessors import Preprocessor
 from maas_lib.pydatasets import PyDataset
 from maas_lib.utils.config import Config
+from maas_lib.utils.hub import get_model_cache_dir
 from .util import is_model_name
 
 Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -26,7 +26,7 @@ class Pipeline(ABC):
     def initiate_single_model(self, model):
         if isinstance(model, str):
             if not osp.exists(model):
-                cache_path = util.get_model_cache_dir(model)
+                cache_path = get_model_cache_dir(model)
                 model = cache_path if osp.exists(
                     cache_path) else snapshot_download(model)
             return Model.from_pretrained(model) if is_model_name(
diff --git a/maas_lib/pipelines/builder.py b/maas_lib/pipelines/builder.py
index dd146cca..cd1eb32f 100644
--- a/maas_lib/pipelines/builder.py
+++ b/maas_lib/pipelines/builder.py
@@ -1,7 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import os.path as osp
-from typing import List, Union
+from typing import Union
 
 import json
 from maas_hub.file_download import model_file_download
@@ -10,7 +10,8 @@ from maas_lib.models.base import Model
 from maas_lib.utils.config import Config, ConfigDict
 from maas_lib.utils.constant import CONFIGFILE, Tasks
 from maas_lib.utils.registry import Registry, build_from_cfg
-from .base import InputModel, Pipeline
+from .base import Pipeline
+from .default import DEFAULT_MODEL_FOR_PIPELINE, get_default_pipeline_info
 from .util import is_model_name
 
 PIPELINES = Registry('pipelines')
@@ -32,7 +33,7 @@ def build_pipeline(cfg: ConfigDict,
 
 
 def pipeline(task: str = None,
-             model: Union[InputModel, List[InputModel]] = None,
+             model: Union[str, Model] = None,
              preprocessor=None,
              config_file: str = None,
              pipeline_name: str = None,
@@ -67,23 +68,19 @@
 
     if pipeline_name is None:
         # get default pipeline for this task
-        assert task in PIPELINES.modules, f'No pipeline is registered for Task {task}'
-        pipeline_name = get_default_pipeline(task)
+        pipeline_name, default_model_repo = get_default_pipeline_info(task)
+        if model is None:
+            model = default_model_repo
+
+    assert isinstance(model, (type(None), str, Model)), \
+        f'model should be either None, str or Model, but got {type(model)}'
+
+    cfg = ConfigDict(type=pipeline_name, model=model)
 
-    cfg = ConfigDict(type=pipeline_name)
     if kwargs:
         cfg.update(kwargs)
 
-    if model:
-        assert isinstance(model, (str, Model, List)), \
-            f'model should be either (list of) str or Model, but got {type(model)}'
-        cfg.model = model
-
     if preprocessor is not None:
         cfg.preprocessor = preprocessor
 
     return build_pipeline(cfg, task_name=task)
-
-
-def get_default_pipeline(task):
-    return list(PIPELINES.modules[task].keys())[0]
diff --git a/maas_lib/pipelines/default.py b/maas_lib/pipelines/default.py
new file mode 100644
index 00000000..5d364288
--- /dev/null
+++ b/maas_lib/pipelines/default.py
@@ -0,0 +1,48 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from maas_lib.utils.constant import Tasks
+
+DEFAULT_MODEL_FOR_PIPELINE = {
+    # TaskName: (pipeline_module_name, model_repo)
+    Tasks.image_matting: ('image-matting', 'damo/image-matting-person'),
+    Tasks.text_classification:
+    ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
+    Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
+    Tasks.image_captioning: ('ofa', None),
+}
+
+
+def add_default_pipeline_info(task: str,
+                              model_name: str,
+                              modelhub_name: str = None,
+                              overwrite: bool = False):
+    """ Add default model for a task.
+
+    Args:
+        task (str): task name.
+        model_name (str): model_name.
+        modelhub_name (str): name for default modelhub.
+        overwrite (bool): overwrite default info.
+    """
+    if not overwrite:
+        assert task not in DEFAULT_MODEL_FOR_PIPELINE, \
+            f'task {task} already has default model.'
+
+    DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
+
+
+def get_default_pipeline_info(task):
+    """ Get default info for certain task.
+
+    Args:
+        task (str): task name.
+
+    Return:
+        A tuple: first element is pipeline name(model_name), second element
+            is modelhub name.
+    """
+    assert task in DEFAULT_MODEL_FOR_PIPELINE, \
+        f'No default pipeline is registered for Task {task}'
+
+    pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
+    return pipeline_name, default_model
diff --git a/maas_lib/pipelines/multi_modal/image_captioning.py b/maas_lib/pipelines/multi_modal/image_captioning.py
index 778354b7..2d8cc618 100644
--- a/maas_lib/pipelines/multi_modal/image_captioning.py
+++ b/maas_lib/pipelines/multi_modal/image_captioning.py
@@ -2,10 +2,6 @@ from typing import Any, Dict
 
 import numpy as np
 import torch
-from fairseq import checkpoint_utils, tasks, utils
-from ofa.models.ofa import OFAModel
-from ofa.tasks.mm_tasks import CaptionTask
-from ofa.utils.eval_utils import eval_caption
 from PIL import Image
 
 from maas_lib.pipelines.base import Input
@@ -24,6 +20,8 @@ class ImageCaptionPipeline(Pipeline):
     def __init__(self, model: str, bpe_dir: str):
         super().__init__()
         # turn on cuda if GPU is available
+        from fairseq import checkpoint_utils, tasks, utils
+        from ofa.tasks.mm_tasks import CaptionTask
 
         tasks.register_task('caption', CaptionTask)
         use_cuda = False
@@ -106,6 +104,8 @@ class ImageCaptionPipeline(Pipeline):
         return sample
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        from ofa.utils.eval_utils import eval_caption
+
         results, _ = eval_caption(self.task, self.generator, self.models,
                                   input)
         return {
diff --git a/maas_lib/pipelines/util.py b/maas_lib/pipelines/util.py
index 4a0a28ec..771e0d2b 100644
--- a/maas_lib/pipelines/util.py
+++ b/maas_lib/pipelines/util.py
@@ -3,21 +3,11 @@ import os
 import os.path as osp
 
 import json
-from maas_hub.constants import MODEL_ID_SEPARATOR
 from maas_hub.file_download import model_file_download
 
 from maas_lib.utils.constant import CONFIGFILE
 
 
-# temp solution before the hub-cache is in place
-def get_model_cache_dir(model_id: str, branch: str = 'master'):
-    model_id_expanded = model_id.replace('/',
-                                         MODEL_ID_SEPARATOR) + '.' + branch
-    default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
-    return os.getenv('MAAS_CACHE',
-                     os.path.join(default_cache_dir, 'hub', model_id_expanded))
-
-
 def is_model_name(model):
     if osp.exists(model):
         if osp.exists(osp.join(model, CONFIGFILE)):
diff --git a/maas_lib/utils/hub.py b/maas_lib/utils/hub.py
new file mode 100644
index 00000000..2f61b148
--- /dev/null
+++ b/maas_lib/utils/hub.py
@@ -0,0 +1,14 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+
+from maas_hub.constants import MODEL_ID_SEPARATOR
+
+
+# temp solution before the hub-cache is in place
+def get_model_cache_dir(model_id: str, branch: str = 'master'):
+    model_id_expanded = model_id.replace('/',
+                                         MODEL_ID_SEPARATOR) + '.' + branch
+    default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
+    return os.getenv('MAAS_CACHE',
+                     os.path.join(default_cache_dir, 'hub', model_id_expanded))
diff --git a/maas_lib/utils/registry.py b/maas_lib/utils/registry.py
index 838e6f83..bac3d616 100644
--- a/maas_lib/utils/registry.py
+++ b/maas_lib/utils/registry.py
@@ -100,6 +100,12 @@ class Registry(object):
             >>> class SwinTransformerDefaultGroup:
             >>>     pass
 
+            >>> class SwinTransformer2:
+            >>>     pass
+            >>> MODELS.register_module('image-classification',
+                    module_name='SwinT2',
+                    module_cls=SwinTransformer2)
+
         Args:
             group_key: Group name of which module will be registered,
                 default group name is 'default'
diff --git a/tests/pipelines/test_base.py b/tests/pipelines/test_base.py
index d523e7c4..5994ddde 100644
--- a/tests/pipelines/test_base.py
+++ b/tests/pipelines/test_base.py
@@ -8,6 +8,7 @@ import PIL
 
 from maas_lib.pipelines import Pipeline, pipeline
 from maas_lib.pipelines.builder import PIPELINES
+from maas_lib.pipelines.default import add_default_pipeline_info
 from maas_lib.utils.constant import Tasks
 from maas_lib.utils.logger import get_logger
 from maas_lib.utils.registry import default_group
@@ -75,6 +76,7 @@ class CustomPipelineTest(unittest.TestCase):
             return inputs
 
         self.assertTrue('custom-image' in PIPELINES.modules[default_group])
+        add_default_pipeline_info(Tasks.image_tagging, 'custom-image')
         pipe = pipeline(pipeline_name='custom-image')
         pipe2 = pipeline(Tasks.image_tagging)
         self.assertTrue(type(pipe) is type(pipe2))
diff --git a/tests/pipelines/test_image_captioning.py b/tests/pipelines/test_image_captioning.py
index f951f0a8..afcab01d 100644
--- a/tests/pipelines/test_image_captioning.py
+++ b/tests/pipelines/test_image_captioning.py
@@ -11,6 +11,7 @@ from maas_lib.utils.constant import Tasks
 
 
 class ImageCaptionTest(unittest.TestCase):
 
+    @unittest.skip('skip long test')
     def test_run(self):
         model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt'
diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py
index 8153b70d..25f19102 100644
--- a/tests/pipelines/test_image_matting.py
+++ b/tests/pipelines/test_image_matting.py
@@ -7,9 +7,10 @@ import unittest
 import cv2
 
 from maas_lib.fileio import File
-from maas_lib.pipelines import pipeline, util
+from maas_lib.pipelines import pipeline
 from maas_lib.pydatasets import PyDataset
 from maas_lib.utils.constant import Tasks
+from maas_lib.utils.hub import get_model_cache_dir
 
 
 class ImageMattingTest(unittest.TestCase):
@@ -20,7 +21,7 @@ class ImageMattingTest(unittest.TestCase):
         purge_cache = True
         if purge_cache:
             shutil.rmtree(
-                util.get_model_cache_dir(self.model_id), ignore_errors=True)
+                get_model_cache_dir(self.model_id), ignore_errors=True)
 
     def test_run(self):
         model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
@@ -59,6 +60,15 @@ class ImageMattingTest(unittest.TestCase):
         cv2.imwrite('result.png', result['output_png'])
         print(f'Output written to {osp.abspath("result.png")}')
 
+    def test_run_modelhub_default_model(self):
+        img_matting = pipeline(Tasks.image_matting)
+
+        result = img_matting(
+            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
+        )
+        cv2.imwrite('result.png', result['output_png'])
+        print(f'Output written to {osp.abspath("result.png")}')
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/pipelines/test_text_classification.py b/tests/pipelines/test_text_classification.py
index b6528319..36285f80 100644
--- a/tests/pipelines/test_text_classification.py
+++ b/tests/pipelines/test_text_classification.py
@@ -7,10 +7,11 @@ from pathlib import Path
 from maas_lib.fileio import File
 from maas_lib.models import Model
 from maas_lib.models.nlp import BertForSequenceClassification
-from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util
+from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
 from maas_lib.preprocessors import SequenceClassificationPreprocessor
 from maas_lib.pydatasets import PyDataset
 from maas_lib.utils.constant import Tasks
+from maas_lib.utils.hub import get_model_cache_dir
 
 
 class SequenceClassificationTest(unittest.TestCase):
@@ -21,7 +22,7 @@ class SequenceClassificationTest(unittest.TestCase):
         purge_cache = True
         if purge_cache:
             shutil.rmtree(
-                util.get_model_cache_dir(self.model_id), ignore_errors=True)
+                get_model_cache_dir(self.model_id), ignore_errors=True)
 
     def predict(self, pipeline_ins: SequenceClassificationPipeline):
         from easynlp.appzoo import load_dataset
@@ -83,6 +84,12 @@ class SequenceClassificationTest(unittest.TestCase):
             PyDataset.load('glue', name='sst2', target='sentence'))
         self.printDataset(result)
 
+    def test_run_with_default_model(self):
+        text_classification = pipeline(task=Tasks.text_classification)
+        result = text_classification(
+            PyDataset.load('glue', name='sst2', target='sentence'))
+        self.printDataset(result)
+
     def test_run_with_dataset(self):
         model = Model.from_pretrained(self.model_id)
         preprocessor = SequenceClassificationPreprocessor(
diff --git a/tests/pipelines/test_text_generation.py b/tests/pipelines/test_text_generation.py
index d59fdabb..235279c2 100644
--- a/tests/pipelines/test_text_generation.py
+++ b/tests/pipelines/test_text_generation.py
@@ -15,6 +15,7 @@ class TextGenerationTest(unittest.TestCase):
     input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
     input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"
 
+    @unittest.skip('skip temporarily to save test time')
    def test_run(self):
         cache_path = snapshot_download(self.model_id)
         preprocessor = TextGenerationPreprocessor(
@@ -41,6 +42,10 @@ class TextGenerationTest(unittest.TestCase):
             task=Tasks.text_generation, model=self.model_id)
         print(pipeline_ins(self.input2))
 
+    def test_run_with_default_model(self):
+        pipeline_ins = pipeline(task=Tasks.text_generation)
+        print(pipeline_ins(self.input2))
+
 
 if __name__ == '__main__':
     unittest.main()
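
For reference, a minimal usage sketch of the default-model mechanism introduced in this change, distilled from the tests above. The task names and the `damo/*` repos come from `DEFAULT_MODEL_FOR_PIPELINE`; actually running it requires network access to the model hub, and the `custom-image` pipeline in the second part is assumed to have been registered beforehand (as in `tests/pipelines/test_base.py`):

```python
from maas_lib.pipelines import pipeline
from maas_lib.pipelines.default import add_default_pipeline_info
from maas_lib.utils.constant import Tasks

# With no model argument, pipeline() now resolves the registered
# (pipeline_name, model_repo) pair for the task and fetches the default
# repo, e.g. 'damo/image-matting-person' for Tasks.image_matting.
img_matting = pipeline(Tasks.image_matting)
result = img_matting(
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
)

# A custom pipeline can register its own default before pipeline(task) is
# called; this assumes a pipeline named 'custom-image' was already added
# to PIPELINES for Tasks.image_tagging.
add_default_pipeline_info(Tasks.image_tagging, 'custom-image')
tagger = pipeline(Tasks.image_tagging)
```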