1. add default model support 2. fix circular import 3. temporarily skip ofa and palm test which costs too much time Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8981076master
| @@ -71,12 +71,18 @@ TODO | |||||
| * Feature | * Feature | ||||
| ```shell | ```shell | ||||
| [to #AONE_ID] feat: commit title | [to #AONE_ID] feat: commit title | ||||
| Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062 | |||||
| * commit msg1 | * commit msg1 | ||||
| * commit msg2 | * commit msg2 | ||||
| ``` | ``` | ||||
| * Bugfix | * Bugfix | ||||
| ```shell | ```shell | ||||
| [to #AONE_ID] fix: commit title | [to #AONE_ID] fix: commit title | ||||
| Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062 | |||||
| * commit msg1 | * commit msg1 | ||||
| * commit msg2 | * commit msg2 | ||||
| ``` | ``` | ||||
| @@ -8,9 +8,9 @@ from maas_hub.file_download import model_file_download | |||||
| from maas_hub.snapshot_download import snapshot_download | from maas_hub.snapshot_download import snapshot_download | ||||
| from maas_lib.models.builder import build_model | from maas_lib.models.builder import build_model | ||||
| from maas_lib.pipelines import util | |||||
| from maas_lib.utils.config import Config | from maas_lib.utils.config import Config | ||||
| from maas_lib.utils.constant import CONFIGFILE | from maas_lib.utils.constant import CONFIGFILE | ||||
| from maas_lib.utils.hub import get_model_cache_dir | |||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| @@ -40,7 +40,7 @@ class Model(ABC): | |||||
| if osp.exists(model_name_or_path): | if osp.exists(model_name_or_path): | ||||
| local_model_dir = model_name_or_path | local_model_dir = model_name_or_path | ||||
| else: | else: | ||||
| cache_path = util.get_model_cache_dir(model_name_or_path) | |||||
| cache_path = get_model_cache_dir(model_name_or_path) | |||||
| local_model_dir = cache_path if osp.exists( | local_model_dir = cache_path if osp.exists( | ||||
| cache_path) else snapshot_download(model_name_or_path) | cache_path) else snapshot_download(model_name_or_path) | ||||
| # else: | # else: | ||||
| @@ -6,11 +6,11 @@ from typing import Any, Dict, Generator, List, Union | |||||
| from maas_hub.snapshot_download import snapshot_download | from maas_hub.snapshot_download import snapshot_download | ||||
| from maas_lib.models import Model | |||||
| from maas_lib.pipelines import util | |||||
| from maas_lib.models.base import Model | |||||
| from maas_lib.preprocessors import Preprocessor | from maas_lib.preprocessors import Preprocessor | ||||
| from maas_lib.pydatasets import PyDataset | from maas_lib.pydatasets import PyDataset | ||||
| from maas_lib.utils.config import Config | from maas_lib.utils.config import Config | ||||
| from maas_lib.utils.hub import get_model_cache_dir | |||||
| from .util import is_model_name | from .util import is_model_name | ||||
| Tensor = Union['torch.Tensor', 'tf.Tensor'] | Tensor = Union['torch.Tensor', 'tf.Tensor'] | ||||
| @@ -26,7 +26,7 @@ class Pipeline(ABC): | |||||
| def initiate_single_model(self, model): | def initiate_single_model(self, model): | ||||
| if isinstance(model, str): | if isinstance(model, str): | ||||
| if not osp.exists(model): | if not osp.exists(model): | ||||
| cache_path = util.get_model_cache_dir(model) | |||||
| cache_path = get_model_cache_dir(model) | |||||
| model = cache_path if osp.exists( | model = cache_path if osp.exists( | ||||
| cache_path) else snapshot_download(model) | cache_path) else snapshot_download(model) | ||||
| return Model.from_pretrained(model) if is_model_name( | return Model.from_pretrained(model) if is_model_name( | ||||
| @@ -1,7 +1,7 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os.path as osp | import os.path as osp | ||||
| from typing import List, Union | |||||
| from typing import Union | |||||
| import json | import json | ||||
| from maas_hub.file_download import model_file_download | from maas_hub.file_download import model_file_download | ||||
| @@ -10,7 +10,8 @@ from maas_lib.models.base import Model | |||||
| from maas_lib.utils.config import Config, ConfigDict | from maas_lib.utils.config import Config, ConfigDict | ||||
| from maas_lib.utils.constant import CONFIGFILE, Tasks | from maas_lib.utils.constant import CONFIGFILE, Tasks | ||||
| from maas_lib.utils.registry import Registry, build_from_cfg | from maas_lib.utils.registry import Registry, build_from_cfg | ||||
| from .base import InputModel, Pipeline | |||||
| from .base import Pipeline | |||||
| from .default import DEFAULT_MODEL_FOR_PIPELINE, get_default_pipeline_info | |||||
| from .util import is_model_name | from .util import is_model_name | ||||
| PIPELINES = Registry('pipelines') | PIPELINES = Registry('pipelines') | ||||
| @@ -32,7 +33,7 @@ def build_pipeline(cfg: ConfigDict, | |||||
| def pipeline(task: str = None, | def pipeline(task: str = None, | ||||
| model: Union[InputModel, List[InputModel]] = None, | |||||
| model: Union[str, Model] = None, | |||||
| preprocessor=None, | preprocessor=None, | ||||
| config_file: str = None, | config_file: str = None, | ||||
| pipeline_name: str = None, | pipeline_name: str = None, | ||||
| @@ -67,23 +68,19 @@ def pipeline(task: str = None, | |||||
| if pipeline_name is None: | if pipeline_name is None: | ||||
| # get default pipeline for this task | # get default pipeline for this task | ||||
| assert task in PIPELINES.modules, f'No pipeline is registered for Task {task}' | |||||
| pipeline_name = get_default_pipeline(task) | |||||
| pipeline_name, default_model_repo = get_default_pipeline_info(task) | |||||
| if model is None: | |||||
| model = default_model_repo | |||||
| assert isinstance(model, (type(None), str, Model)), \ | |||||
| f'model should be either None, str or Model, but got {type(model)}' | |||||
| cfg = ConfigDict(type=pipeline_name, model=model) | |||||
| cfg = ConfigDict(type=pipeline_name) | |||||
| if kwargs: | if kwargs: | ||||
| cfg.update(kwargs) | cfg.update(kwargs) | ||||
| if model: | |||||
| assert isinstance(model, (str, Model, List)), \ | |||||
| f'model should be either (list of) str or Model, but got {type(model)}' | |||||
| cfg.model = model | |||||
| if preprocessor is not None: | if preprocessor is not None: | ||||
| cfg.preprocessor = preprocessor | cfg.preprocessor = preprocessor | ||||
| return build_pipeline(cfg, task_name=task) | return build_pipeline(cfg, task_name=task) | ||||
| def get_default_pipeline(task): | |||||
| return list(PIPELINES.modules[task].keys())[0] | |||||
| @@ -0,0 +1,48 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from maas_lib.utils.constant import Tasks | |||||
| DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| # TaskName: (pipeline_module_name, model_repo) | |||||
| Tasks.image_matting: ('image-matting', 'damo/image-matting-person'), | |||||
| Tasks.text_classification: | |||||
| ('bert-sentiment-analysis', 'damo/bert-base-sst2'), | |||||
| Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'), | |||||
| Tasks.image_captioning: ('ofa', None), | |||||
| } | |||||
| def add_default_pipeline_info(task: str, | |||||
| model_name: str, | |||||
| modelhub_name: str = None, | |||||
| overwrite: bool = False): | |||||
| """ Add default model for a task. | |||||
| Args: | |||||
| task (str): task name. | |||||
| model_name (str): model_name. | |||||
| modelhub_name (str): name for default modelhub. | |||||
| overwrite (bool): overwrite default info. | |||||
| """ | |||||
| if not overwrite: | |||||
| assert task not in DEFAULT_MODEL_FOR_PIPELINE, \ | |||||
| f'task {task} already has default model.' | |||||
| DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name) | |||||
| def get_default_pipeline_info(task): | |||||
| """ Get default info for certain task. | |||||
| Args: | |||||
| task (str): task name. | |||||
| Return: | |||||
| A tuple: first element is pipeline name(model_name), second element | |||||
| is modelhub name. | |||||
| """ | |||||
| assert task in DEFAULT_MODEL_FOR_PIPELINE, \ | |||||
| f'No default pipeline is registered for Task {task}' | |||||
| pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task] | |||||
| return pipeline_name, default_model | |||||
| @@ -2,10 +2,6 @@ from typing import Any, Dict | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| from fairseq import checkpoint_utils, tasks, utils | |||||
| from ofa.models.ofa import OFAModel | |||||
| from ofa.tasks.mm_tasks import CaptionTask | |||||
| from ofa.utils.eval_utils import eval_caption | |||||
| from PIL import Image | from PIL import Image | ||||
| from maas_lib.pipelines.base import Input | from maas_lib.pipelines.base import Input | ||||
| @@ -24,6 +20,8 @@ class ImageCaptionPipeline(Pipeline): | |||||
| def __init__(self, model: str, bpe_dir: str): | def __init__(self, model: str, bpe_dir: str): | ||||
| super().__init__() | super().__init__() | ||||
| # turn on cuda if GPU is available | # turn on cuda if GPU is available | ||||
| from fairseq import checkpoint_utils, tasks, utils | |||||
| from ofa.tasks.mm_tasks import CaptionTask | |||||
| tasks.register_task('caption', CaptionTask) | tasks.register_task('caption', CaptionTask) | ||||
| use_cuda = False | use_cuda = False | ||||
| @@ -106,6 +104,8 @@ class ImageCaptionPipeline(Pipeline): | |||||
| return sample | return sample | ||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | ||||
| from ofa.utils.eval_utils import eval_caption | |||||
| results, _ = eval_caption(self.task, self.generator, self.models, | results, _ = eval_caption(self.task, self.generator, self.models, | ||||
| input) | input) | ||||
| return { | return { | ||||
| @@ -3,21 +3,11 @@ import os | |||||
| import os.path as osp | import os.path as osp | ||||
| import json | import json | ||||
| from maas_hub.constants import MODEL_ID_SEPARATOR | |||||
| from maas_hub.file_download import model_file_download | from maas_hub.file_download import model_file_download | ||||
| from maas_lib.utils.constant import CONFIGFILE | from maas_lib.utils.constant import CONFIGFILE | ||||
| # temp solution before the hub-cache is in place | |||||
| def get_model_cache_dir(model_id: str, branch: str = 'master'): | |||||
| model_id_expanded = model_id.replace('/', | |||||
| MODEL_ID_SEPARATOR) + '.' + branch | |||||
| default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas')) | |||||
| return os.getenv('MAAS_CACHE', | |||||
| os.path.join(default_cache_dir, 'hub', model_id_expanded)) | |||||
| def is_model_name(model): | def is_model_name(model): | ||||
| if osp.exists(model): | if osp.exists(model): | ||||
| if osp.exists(osp.join(model, CONFIGFILE)): | if osp.exists(osp.join(model, CONFIGFILE)): | ||||
| @@ -0,0 +1,14 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| from maas_hub.constants import MODEL_ID_SEPARATOR | |||||
| # temp solution before the hub-cache is in place | |||||
| def get_model_cache_dir(model_id: str, branch: str = 'master'): | |||||
| model_id_expanded = model_id.replace('/', | |||||
| MODEL_ID_SEPARATOR) + '.' + branch | |||||
| default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas')) | |||||
| return os.getenv('MAAS_CACHE', | |||||
| os.path.join(default_cache_dir, 'hub', model_id_expanded)) | |||||
| @@ -100,6 +100,12 @@ class Registry(object): | |||||
| >>> class SwinTransformerDefaultGroup: | >>> class SwinTransformerDefaultGroup: | ||||
| >>> pass | >>> pass | ||||
| >>> class SwinTransformer2: | |||||
| >>> pass | |||||
| >>> MODELS.register_module('image-classification', | |||||
| module_name='SwinT2', | |||||
| module_cls=SwinTransformer2) | |||||
| Args: | Args: | ||||
| group_key: Group name of which module will be registered, | group_key: Group name of which module will be registered, | ||||
| default group name is 'default' | default group name is 'default' | ||||
| @@ -8,6 +8,7 @@ import PIL | |||||
| from maas_lib.pipelines import Pipeline, pipeline | from maas_lib.pipelines import Pipeline, pipeline | ||||
| from maas_lib.pipelines.builder import PIPELINES | from maas_lib.pipelines.builder import PIPELINES | ||||
| from maas_lib.pipelines.default import add_default_pipeline_info | |||||
| from maas_lib.utils.constant import Tasks | from maas_lib.utils.constant import Tasks | ||||
| from maas_lib.utils.logger import get_logger | from maas_lib.utils.logger import get_logger | ||||
| from maas_lib.utils.registry import default_group | from maas_lib.utils.registry import default_group | ||||
| @@ -75,6 +76,7 @@ class CustomPipelineTest(unittest.TestCase): | |||||
| return inputs | return inputs | ||||
| self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | self.assertTrue('custom-image' in PIPELINES.modules[default_group]) | ||||
| add_default_pipeline_info(Tasks.image_tagging, 'custom-image') | |||||
| pipe = pipeline(pipeline_name='custom-image') | pipe = pipeline(pipeline_name='custom-image') | ||||
| pipe2 = pipeline(Tasks.image_tagging) | pipe2 = pipeline(Tasks.image_tagging) | ||||
| self.assertTrue(type(pipe) is type(pipe2)) | self.assertTrue(type(pipe) is type(pipe2)) | ||||
| @@ -11,6 +11,7 @@ from maas_lib.utils.constant import Tasks | |||||
| class ImageCaptionTest(unittest.TestCase): | class ImageCaptionTest(unittest.TestCase): | ||||
| @unittest.skip('skip long test') | |||||
| def test_run(self): | def test_run(self): | ||||
| model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt' | model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt' | ||||
| @@ -7,9 +7,10 @@ import unittest | |||||
| import cv2 | import cv2 | ||||
| from maas_lib.fileio import File | from maas_lib.fileio import File | ||||
| from maas_lib.pipelines import pipeline, util | |||||
| from maas_lib.pipelines import pipeline | |||||
| from maas_lib.pydatasets import PyDataset | from maas_lib.pydatasets import PyDataset | ||||
| from maas_lib.utils.constant import Tasks | from maas_lib.utils.constant import Tasks | ||||
| from maas_lib.utils.hub import get_model_cache_dir | |||||
| class ImageMattingTest(unittest.TestCase): | class ImageMattingTest(unittest.TestCase): | ||||
| @@ -20,7 +21,7 @@ class ImageMattingTest(unittest.TestCase): | |||||
| purge_cache = True | purge_cache = True | ||||
| if purge_cache: | if purge_cache: | ||||
| shutil.rmtree( | shutil.rmtree( | ||||
| util.get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| def test_run(self): | def test_run(self): | ||||
| model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | ||||
| @@ -59,6 +60,15 @@ class ImageMattingTest(unittest.TestCase): | |||||
| cv2.imwrite('result.png', result['output_png']) | cv2.imwrite('result.png', result['output_png']) | ||||
| print(f'Output written to {osp.abspath("result.png")}') | print(f'Output written to {osp.abspath("result.png")}') | ||||
| def test_run_modelhub_default_model(self): | |||||
| img_matting = pipeline(Tasks.image_matting) | |||||
| result = img_matting( | |||||
| 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||||
| ) | |||||
| cv2.imwrite('result.png', result['output_png']) | |||||
| print(f'Output written to {osp.abspath("result.png")}') | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||
| @@ -7,10 +7,11 @@ from pathlib import Path | |||||
| from maas_lib.fileio import File | from maas_lib.fileio import File | ||||
| from maas_lib.models import Model | from maas_lib.models import Model | ||||
| from maas_lib.models.nlp import BertForSequenceClassification | from maas_lib.models.nlp import BertForSequenceClassification | ||||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util | |||||
| from maas_lib.pipelines import SequenceClassificationPipeline, pipeline | |||||
| from maas_lib.preprocessors import SequenceClassificationPreprocessor | from maas_lib.preprocessors import SequenceClassificationPreprocessor | ||||
| from maas_lib.pydatasets import PyDataset | from maas_lib.pydatasets import PyDataset | ||||
| from maas_lib.utils.constant import Tasks | from maas_lib.utils.constant import Tasks | ||||
| from maas_lib.utils.hub import get_model_cache_dir | |||||
| class SequenceClassificationTest(unittest.TestCase): | class SequenceClassificationTest(unittest.TestCase): | ||||
| @@ -21,7 +22,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| purge_cache = True | purge_cache = True | ||||
| if purge_cache: | if purge_cache: | ||||
| shutil.rmtree( | shutil.rmtree( | ||||
| util.get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| get_model_cache_dir(self.model_id), ignore_errors=True) | |||||
| def predict(self, pipeline_ins: SequenceClassificationPipeline): | def predict(self, pipeline_ins: SequenceClassificationPipeline): | ||||
| from easynlp.appzoo import load_dataset | from easynlp.appzoo import load_dataset | ||||
| @@ -83,6 +84,12 @@ class SequenceClassificationTest(unittest.TestCase): | |||||
| PyDataset.load('glue', name='sst2', target='sentence')) | PyDataset.load('glue', name='sst2', target='sentence')) | ||||
| self.printDataset(result) | self.printDataset(result) | ||||
| def test_run_with_default_model(self): | |||||
| text_classification = pipeline(task=Tasks.text_classification) | |||||
| result = text_classification( | |||||
| PyDataset.load('glue', name='sst2', target='sentence')) | |||||
| self.printDataset(result) | |||||
| def test_run_with_dataset(self): | def test_run_with_dataset(self): | ||||
| model = Model.from_pretrained(self.model_id) | model = Model.from_pretrained(self.model_id) | ||||
| preprocessor = SequenceClassificationPreprocessor( | preprocessor = SequenceClassificationPreprocessor( | ||||
| @@ -15,6 +15,7 @@ class TextGenerationTest(unittest.TestCase): | |||||
| input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'" | input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'" | ||||
| input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'" | input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'" | ||||
| @unittest.skip('skip temporarily to save test time') | |||||
| def test_run(self): | def test_run(self): | ||||
| cache_path = snapshot_download(self.model_id) | cache_path = snapshot_download(self.model_id) | ||||
| preprocessor = TextGenerationPreprocessor( | preprocessor = TextGenerationPreprocessor( | ||||
| @@ -41,6 +42,10 @@ class TextGenerationTest(unittest.TestCase): | |||||
| task=Tasks.text_generation, model=self.model_id) | task=Tasks.text_generation, model=self.model_id) | ||||
| print(pipeline_ins(self.input2)) | print(pipeline_ins(self.input2)) | ||||
| def test_run_with_default_model(self): | |||||
| pipeline_ins = pipeline(task=Tasks.text_generation) | |||||
| print(pipeline_ins(self.input2)) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||