1. Add default model support. 2. Fix circular import. 3. Temporarily skip the OFA and PALM tests, which take too much time. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8981076 master
| @@ -71,12 +71,18 @@ TODO | |||
| * Feature | |||
| ```shell | |||
| [to #AONE_ID] feat: commit title | |||
| Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062 | |||
| * commit msg1 | |||
| * commit msg2 | |||
| ``` | |||
| * Bugfix | |||
| ```shell | |||
| [to #AONE_ID] fix: commit title | |||
| Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062 | |||
| * commit msg1 | |||
| * commit msg2 | |||
| ``` | |||
| @@ -8,9 +8,9 @@ from maas_hub.file_download import model_file_download | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from maas_lib.models.builder import build_model | |||
| from maas_lib.pipelines import util | |||
| from maas_lib.utils.config import Config | |||
| from maas_lib.utils.constant import CONFIGFILE | |||
| from maas_lib.utils.hub import get_model_cache_dir | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| @@ -40,7 +40,7 @@ class Model(ABC): | |||
| if osp.exists(model_name_or_path): | |||
| local_model_dir = model_name_or_path | |||
| else: | |||
| cache_path = util.get_model_cache_dir(model_name_or_path) | |||
| cache_path = get_model_cache_dir(model_name_or_path) | |||
| local_model_dir = cache_path if osp.exists( | |||
| cache_path) else snapshot_download(model_name_or_path) | |||
| # else: | |||
| @@ -6,11 +6,11 @@ from typing import Any, Dict, Generator, List, Union | |||
| from maas_hub.snapshot_download import snapshot_download | |||
| from maas_lib.models import Model | |||
| from maas_lib.pipelines import util | |||
| from maas_lib.models.base import Model | |||
| from maas_lib.preprocessors import Preprocessor | |||
| from maas_lib.pydatasets import PyDataset | |||
| from maas_lib.utils.config import Config | |||
| from maas_lib.utils.hub import get_model_cache_dir | |||
| from .util import is_model_name | |||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
| @@ -26,7 +26,7 @@ class Pipeline(ABC): | |||
| def initiate_single_model(self, model): | |||
| if isinstance(model, str): | |||
| if not osp.exists(model): | |||
| cache_path = util.get_model_cache_dir(model) | |||
| cache_path = get_model_cache_dir(model) | |||
| model = cache_path if osp.exists( | |||
| cache_path) else snapshot_download(model) | |||
| return Model.from_pretrained(model) if is_model_name( | |||
| @@ -1,7 +1,7 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from typing import List, Union | |||
| from typing import Union | |||
| import json | |||
| from maas_hub.file_download import model_file_download | |||
| @@ -10,7 +10,8 @@ from maas_lib.models.base import Model | |||
| from maas_lib.utils.config import Config, ConfigDict | |||
| from maas_lib.utils.constant import CONFIGFILE, Tasks | |||
| from maas_lib.utils.registry import Registry, build_from_cfg | |||
| from .base import InputModel, Pipeline | |||
| from .base import Pipeline | |||
| from .default import DEFAULT_MODEL_FOR_PIPELINE, get_default_pipeline_info | |||
| from .util import is_model_name | |||
| PIPELINES = Registry('pipelines') | |||
| @@ -32,7 +33,7 @@ def build_pipeline(cfg: ConfigDict, | |||
| def pipeline(task: str = None, | |||
| model: Union[InputModel, List[InputModel]] = None, | |||
| model: Union[str, Model] = None, | |||
| preprocessor=None, | |||
| config_file: str = None, | |||
| pipeline_name: str = None, | |||
| @@ -67,23 +68,19 @@ def pipeline(task: str = None, | |||
| if pipeline_name is None: | |||
| # get default pipeline for this task | |||
| assert task in PIPELINES.modules, f'No pipeline is registered for Task {task}' | |||
| pipeline_name = get_default_pipeline(task) | |||
| pipeline_name, default_model_repo = get_default_pipeline_info(task) | |||
| if model is None: | |||
| model = default_model_repo | |||
| assert isinstance(model, (type(None), str, Model)), \ | |||
| f'model should be either None, str or Model, but got {type(model)}' | |||
| cfg = ConfigDict(type=pipeline_name, model=model) | |||
| cfg = ConfigDict(type=pipeline_name) | |||
| if kwargs: | |||
| cfg.update(kwargs) | |||
| if model: | |||
| assert isinstance(model, (str, Model, List)), \ | |||
| f'model should be either (list of) str or Model, but got {type(model)}' | |||
| cfg.model = model | |||
| if preprocessor is not None: | |||
| cfg.preprocessor = preprocessor | |||
| return build_pipeline(cfg, task_name=task) | |||
| def get_default_pipeline(task): | |||
| return list(PIPELINES.modules[task].keys())[0] | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from maas_lib.utils.constant import Tasks | |||
# Per-task defaults returned by the default-pipeline lookup.
# TaskName: (pipeline_module_name, model_repo)
# model_repo may be None, in which case no default model is supplied.
DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.image_matting: (
        'image-matting',
        'damo/image-matting-person',
    ),
    Tasks.text_classification: (
        'bert-sentiment-analysis',
        'damo/bert-base-sst2',
    ),
    Tasks.text_generation: (
        'palm',
        'damo/nlp_palm_text-generation_chinese',
    ),
    Tasks.image_captioning: (
        'ofa',
        None,
    ),
}
def add_default_pipeline_info(task: str,
                              model_name: str,
                              modelhub_name: str = None,
                              overwrite: bool = False):
    """ Register the default pipeline and model repo for a task.

    Args:
        task (str): task name.
        model_name (str): pipeline name stored as the first element of the
            task's default entry.
        modelhub_name (str): modelhub name stored as the second element of
            the task's default entry; may be None.
        overwrite (bool): replace an existing default entry instead of
            raising.

    Raises:
        AssertionError: if the task already has a default entry and
            overwrite is False.
    """
    # Raise explicitly instead of using an `assert` statement: asserts are
    # stripped under `python -O`, which would silently turn this guard into
    # an unconditional overwrite. The exception type stays AssertionError so
    # existing callers are unaffected.
    if not overwrite and task in DEFAULT_MODEL_FOR_PIPELINE:
        raise AssertionError(f'task {task} already has default model.')

    DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
def get_default_pipeline_info(task):
    """ Look up the default pipeline and model repo for a task.

    Args:
        task (str): task name.

    Return:
        A tuple: first element is pipeline name(model_name), second element
            is modelhub name.

    Raises:
        AssertionError: if no default is registered for the task.
    """
    # One lookup with an explicit raise keeps the descriptive error even under
    # `python -O`, where an `assert` statement would be stripped and a bare
    # KeyError would surface instead. The exception type stays AssertionError.
    try:
        pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
    except KeyError:
        raise AssertionError(
            f'No default pipeline is registered for Task {task}') from None
    return pipeline_name, default_model
| @@ -2,10 +2,6 @@ from typing import Any, Dict | |||
| import numpy as np | |||
| import torch | |||
| from fairseq import checkpoint_utils, tasks, utils | |||
| from ofa.models.ofa import OFAModel | |||
| from ofa.tasks.mm_tasks import CaptionTask | |||
| from ofa.utils.eval_utils import eval_caption | |||
| from PIL import Image | |||
| from maas_lib.pipelines.base import Input | |||
| @@ -24,6 +20,8 @@ class ImageCaptionPipeline(Pipeline): | |||
| def __init__(self, model: str, bpe_dir: str): | |||
| super().__init__() | |||
| # turn on cuda if GPU is available | |||
| from fairseq import checkpoint_utils, tasks, utils | |||
| from ofa.tasks.mm_tasks import CaptionTask | |||
| tasks.register_task('caption', CaptionTask) | |||
| use_cuda = False | |||
| @@ -106,6 +104,8 @@ class ImageCaptionPipeline(Pipeline): | |||
| return sample | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| from ofa.utils.eval_utils import eval_caption | |||
| results, _ = eval_caption(self.task, self.generator, self.models, | |||
| input) | |||
| return { | |||
| @@ -3,21 +3,11 @@ import os | |||
| import os.path as osp | |||
| import json | |||
| from maas_hub.constants import MODEL_ID_SEPARATOR | |||
| from maas_hub.file_download import model_file_download | |||
| from maas_lib.utils.constant import CONFIGFILE | |||
| # temp solution before the hub-cache is in place | |||
| def get_model_cache_dir(model_id: str, branch: str = 'master'): | |||
| model_id_expanded = model_id.replace('/', | |||
| MODEL_ID_SEPARATOR) + '.' + branch | |||
| default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas')) | |||
| return os.getenv('MAAS_CACHE', | |||
| os.path.join(default_cache_dir, 'hub', model_id_expanded)) | |||
| def is_model_name(model): | |||
| if osp.exists(model): | |||
| if osp.exists(osp.join(model, CONFIGFILE)): | |||
| @@ -0,0 +1,14 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| from maas_hub.constants import MODEL_ID_SEPARATOR | |||
| # temp solution before the hub-cache is in place | |||
def get_model_cache_dir(model_id: str, branch: str = 'master'):
    """Return the local cache directory for one branch of a hub model.

    Temporary solution until the hub-side cache is in place.

    Args:
        model_id (str): model id; every '/' is replaced with
            MODEL_ID_SEPARATOR to form a single directory name.
        branch (str): branch name, appended to the directory name after a
            '.'.

    Returns:
        str: ``<cache root>/<expanded model id>.<branch>``, where the cache
        root is the MAAS_CACHE environment variable if it is set and
        ``~/.cache/maas/hub`` otherwise.
    """
    model_id_expanded = model_id.replace('/',
                                         MODEL_ID_SEPARATOR) + '.' + branch
    default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
    # MAAS_CACHE overrides only the cache root. Returning it verbatim would
    # map every model id to the same directory, so an existence check on the
    # result could not tell models apart and removing the result would delete
    # the whole cache.
    # NOTE(review): assumes the hub client also appends the expanded model id
    # below MAAS_CACHE -- confirm against maas_hub.
    cache_root = os.getenv('MAAS_CACHE',
                           os.path.join(default_cache_dir, 'hub'))
    return os.path.join(cache_root, model_id_expanded)
| @@ -100,6 +100,12 @@ class Registry(object): | |||
| >>> class SwinTransformerDefaultGroup: | |||
| >>> pass | |||
| >>> class SwinTransformer2: | |||
| >>> pass | |||
| >>> MODELS.register_module('image-classification', | |||
| module_name='SwinT2', | |||
| module_cls=SwinTransformer2) | |||
| Args: | |||
| group_key: Group name of which module will be registered, | |||
| default group name is 'default' | |||
| @@ -8,6 +8,7 @@ import PIL | |||
| from maas_lib.pipelines import Pipeline, pipeline | |||
| from maas_lib.pipelines.builder import PIPELINES | |||
| from maas_lib.pipelines.default import add_default_pipeline_info | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.logger import get_logger | |||
| from maas_lib.utils.registry import default_group | |||
| @@ -75,6 +76,7 @@ class CustomPipelineTest(unittest.TestCase): | |||
| return inputs | |||
| self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | |||
| add_default_pipeline_info(Tasks.image_tagging, 'custom-image') | |||
| pipe = pipeline(pipeline_name='custom-image') | |||
| pipe2 = pipeline(Tasks.image_tagging) | |||
| self.assertTrue(type(pipe) is type(pipe2)) | |||
| @@ -11,6 +11,7 @@ from maas_lib.utils.constant import Tasks | |||
| class ImageCaptionTest(unittest.TestCase): | |||
| @unittest.skip('skip long test') | |||
| def test_run(self): | |||
| model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt' | |||
| @@ -7,9 +7,10 @@ import unittest | |||
| import cv2 | |||
| from maas_lib.fileio import File | |||
| from maas_lib.pipelines import pipeline, util | |||
| from maas_lib.pipelines import pipeline | |||
| from maas_lib.pydatasets import PyDataset | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.hub import get_model_cache_dir | |||
| class ImageMattingTest(unittest.TestCase): | |||
| @@ -20,7 +21,7 @@ class ImageMattingTest(unittest.TestCase): | |||
| purge_cache = True | |||
| if purge_cache: | |||
| shutil.rmtree( | |||
| util.get_model_cache_dir(self.model_id), ignore_errors=True) | |||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||
| def test_run(self): | |||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | |||
| @@ -59,6 +60,15 @@ class ImageMattingTest(unittest.TestCase): | |||
| cv2.imwrite('result.png', result['output_png']) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
def test_run_modelhub_default_model(self):
    """Run image matting with only the task given and save the output."""
    # Only the task is passed; no model or pipeline name is specified.
    matting_pipeline = pipeline(Tasks.image_matting)
    image_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
    matting_output = matting_pipeline(image_url)
    cv2.imwrite('result.png', matting_output['output_png'])
    print(f'Output written to {osp.abspath("result.png")}')
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -7,10 +7,11 @@ from pathlib import Path | |||
| from maas_lib.fileio import File | |||
| from maas_lib.models import Model | |||
| from maas_lib.models.nlp import BertForSequenceClassification | |||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util | |||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline | |||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
| from maas_lib.pydatasets import PyDataset | |||
| from maas_lib.utils.constant import Tasks | |||
| from maas_lib.utils.hub import get_model_cache_dir | |||
| class SequenceClassificationTest(unittest.TestCase): | |||
| @@ -21,7 +22,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| purge_cache = True | |||
| if purge_cache: | |||
| shutil.rmtree( | |||
| util.get_model_cache_dir(self.model_id), ignore_errors=True) | |||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||
| def predict(self, pipeline_ins: SequenceClassificationPipeline): | |||
| from easynlp.appzoo import load_dataset | |||
| @@ -83,6 +84,12 @@ class SequenceClassificationTest(unittest.TestCase): | |||
| PyDataset.load('glue', name='sst2', target='sentence')) | |||
| self.printDataset(result) | |||
def test_run_with_default_model(self):
    """Classify the 'sentence' field of glue/sst2 with only the task given."""
    # Only the task is passed; no model or pipeline name is specified.
    default_classifier = pipeline(task=Tasks.text_classification)
    sst2_sentences = PyDataset.load('glue', name='sst2', target='sentence')
    predictions = default_classifier(sst2_sentences)
    self.printDataset(predictions)
| def test_run_with_dataset(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = SequenceClassificationPreprocessor( | |||
| @@ -15,6 +15,7 @@ class TextGenerationTest(unittest.TestCase): | |||
| input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'" | |||
| input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'" | |||
| @unittest.skip('skip temporarily to save test time') | |||
| def test_run(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| preprocessor = TextGenerationPreprocessor( | |||
| @@ -41,6 +42,10 @@ class TextGenerationTest(unittest.TestCase): | |||
| task=Tasks.text_generation, model=self.model_id) | |||
| print(pipeline_ins(self.input2)) | |||
def test_run_with_default_model(self):
    """Generate text for input2 with only the task given."""
    # Only the task is passed; no model or pipeline name is specified.
    default_generator = pipeline(task=Tasks.text_generation)
    generated = default_generator(self.input2)
    print(generated)
| if __name__ == '__main__': | |||
| unittest.main() | |||