Browse Source

[to #42362853] add default model support and fix circular import

1. add default model support
2. fix circular import
3. temporarily skip ofa and palm test which costs too much time

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8981076
master
wenmeng.zwm 3 years ago
parent
commit
dd00195814
14 changed files with 124 additions and 38 deletions
  1. +6
    -0
      docs/source/develop.md
  2. +2
    -2
      maas_lib/models/base.py
  3. +3
    -3
      maas_lib/pipelines/base.py
  4. +12
    -15
      maas_lib/pipelines/builder.py
  5. +48
    -0
      maas_lib/pipelines/default.py
  6. +4
    -4
      maas_lib/pipelines/multi_modal/image_captioning.py
  7. +0
    -10
      maas_lib/pipelines/util.py
  8. +14
    -0
      maas_lib/utils/hub.py
  9. +6
    -0
      maas_lib/utils/registry.py
  10. +2
    -0
      tests/pipelines/test_base.py
  11. +1
    -0
      tests/pipelines/test_image_captioning.py
  12. +12
    -2
      tests/pipelines/test_image_matting.py
  13. +9
    -2
      tests/pipelines/test_text_classification.py
  14. +5
    -0
      tests/pipelines/test_text_generation.py

+ 6
- 0
docs/source/develop.md View File

@@ -71,12 +71,18 @@ TODO
* Feature
```shell
[to #AONE_ID] feat: commit title

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062

* commit msg1
* commit msg2
```
* Bugfix
```shell
[to #AONE_ID] fix: commit title

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8949062

* commit msg1
* commit msg2
```


+ 2
- 2
maas_lib/models/base.py View File

@@ -8,9 +8,9 @@ from maas_hub.file_download import model_file_download
from maas_hub.snapshot_download import snapshot_download

from maas_lib.models.builder import build_model
from maas_lib.pipelines import util
from maas_lib.utils.config import Config
from maas_lib.utils.constant import CONFIGFILE
from maas_lib.utils.hub import get_model_cache_dir

Tensor = Union['torch.Tensor', 'tf.Tensor']

@@ -40,7 +40,7 @@ class Model(ABC):
if osp.exists(model_name_or_path):
local_model_dir = model_name_or_path
else:
cache_path = util.get_model_cache_dir(model_name_or_path)
cache_path = get_model_cache_dir(model_name_or_path)
local_model_dir = cache_path if osp.exists(
cache_path) else snapshot_download(model_name_or_path)
# else:


+ 3
- 3
maas_lib/pipelines/base.py View File

@@ -6,11 +6,11 @@ from typing import Any, Dict, Generator, List, Union

from maas_hub.snapshot_download import snapshot_download

from maas_lib.models import Model
from maas_lib.pipelines import util
from maas_lib.models.base import Model
from maas_lib.preprocessors import Preprocessor
from maas_lib.pydatasets import PyDataset
from maas_lib.utils.config import Config
from maas_lib.utils.hub import get_model_cache_dir
from .util import is_model_name

Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -26,7 +26,7 @@ class Pipeline(ABC):
def initiate_single_model(self, model):
if isinstance(model, str):
if not osp.exists(model):
cache_path = util.get_model_cache_dir(model)
cache_path = get_model_cache_dir(model)
model = cache_path if osp.exists(
cache_path) else snapshot_download(model)
return Model.from_pretrained(model) if is_model_name(


+ 12
- 15
maas_lib/pipelines/builder.py View File

@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
from typing import List, Union
from typing import Union

import json
from maas_hub.file_download import model_file_download
@@ -10,7 +10,8 @@ from maas_lib.models.base import Model
from maas_lib.utils.config import Config, ConfigDict
from maas_lib.utils.constant import CONFIGFILE, Tasks
from maas_lib.utils.registry import Registry, build_from_cfg
from .base import InputModel, Pipeline
from .base import Pipeline
from .default import DEFAULT_MODEL_FOR_PIPELINE, get_default_pipeline_info
from .util import is_model_name

PIPELINES = Registry('pipelines')
@@ -32,7 +33,7 @@ def build_pipeline(cfg: ConfigDict,


def pipeline(task: str = None,
model: Union[InputModel, List[InputModel]] = None,
model: Union[str, Model] = None,
preprocessor=None,
config_file: str = None,
pipeline_name: str = None,
@@ -67,23 +68,19 @@ def pipeline(task: str = None,

if pipeline_name is None:
# get default pipeline for this task
assert task in PIPELINES.modules, f'No pipeline is registered for Task {task}'
pipeline_name = get_default_pipeline(task)
pipeline_name, default_model_repo = get_default_pipeline_info(task)
if model is None:
model = default_model_repo

assert isinstance(model, (type(None), str, Model)), \
f'model should be either None, str or Model, but got {type(model)}'

cfg = ConfigDict(type=pipeline_name, model=model)

cfg = ConfigDict(type=pipeline_name)
if kwargs:
cfg.update(kwargs)

if model:
assert isinstance(model, (str, Model, List)), \
f'model should be either (list of) str or Model, but got {type(model)}'
cfg.model = model

if preprocessor is not None:
cfg.preprocessor = preprocessor

return build_pipeline(cfg, task_name=task)


def get_default_pipeline(task):
return list(PIPELINES.modules[task].keys())[0]

+ 48
- 0
maas_lib/pipelines/default.py View File

@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from maas_lib.utils.constant import Tasks

DEFAULT_MODEL_FOR_PIPELINE = {
# TaskName: (pipeline_module_name, model_repo)
Tasks.image_matting: ('image-matting', 'damo/image-matting-person'),
Tasks.text_classification:
('bert-sentiment-analysis', 'damo/bert-base-sst2'),
Tasks.text_generation: ('palm', 'damo/nlp_palm_text-generation_chinese'),
Tasks.image_captioning: ('ofa', None),
}


def add_default_pipeline_info(task: str,
model_name: str,
modelhub_name: str = None,
overwrite: bool = False):
""" Add default model for a task.

Args:
task (str): task name.
model_name (str): model_name.
modelhub_name (str): name for default modelhub.
overwrite (bool): overwrite default info.
"""
if not overwrite:
assert task not in DEFAULT_MODEL_FOR_PIPELINE, \
f'task {task} already has default model.'

DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)


def get_default_pipeline_info(task):
""" Get default info for certain task.

Args:
task (str): task name.

Return:
A tuple: first element is pipeline name(model_name), second element
is modelhub name.
"""
assert task in DEFAULT_MODEL_FOR_PIPELINE, \
f'No default pipeline is registered for Task {task}'

pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
return pipeline_name, default_model

+ 4
- 4
maas_lib/pipelines/multi_modal/image_captioning.py View File

@@ -2,10 +2,6 @@ from typing import Any, Dict

import numpy as np
import torch
from fairseq import checkpoint_utils, tasks, utils
from ofa.models.ofa import OFAModel
from ofa.tasks.mm_tasks import CaptionTask
from ofa.utils.eval_utils import eval_caption
from PIL import Image

from maas_lib.pipelines.base import Input
@@ -24,6 +20,8 @@ class ImageCaptionPipeline(Pipeline):
def __init__(self, model: str, bpe_dir: str):
super().__init__()
# turn on cuda if GPU is available
from fairseq import checkpoint_utils, tasks, utils
from ofa.tasks.mm_tasks import CaptionTask

tasks.register_task('caption', CaptionTask)
use_cuda = False
@@ -106,6 +104,8 @@ class ImageCaptionPipeline(Pipeline):
return sample

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
from ofa.utils.eval_utils import eval_caption

results, _ = eval_caption(self.task, self.generator, self.models,
input)
return {


+ 0
- 10
maas_lib/pipelines/util.py View File

@@ -3,21 +3,11 @@ import os
import os.path as osp

import json
from maas_hub.constants import MODEL_ID_SEPARATOR
from maas_hub.file_download import model_file_download

from maas_lib.utils.constant import CONFIGFILE


# temp solution before the hub-cache is in place
def get_model_cache_dir(model_id: str, branch: str = 'master'):
model_id_expanded = model_id.replace('/',
MODEL_ID_SEPARATOR) + '.' + branch
default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
return os.getenv('MAAS_CACHE',
os.path.join(default_cache_dir, 'hub', model_id_expanded))


def is_model_name(model):
if osp.exists(model):
if osp.exists(osp.join(model, CONFIGFILE)):


+ 14
- 0
maas_lib/utils/hub.py View File

@@ -0,0 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os

from maas_hub.constants import MODEL_ID_SEPARATOR


# temp solution before the hub-cache is in place
def get_model_cache_dir(model_id: str, branch: str = 'master'):
model_id_expanded = model_id.replace('/',
MODEL_ID_SEPARATOR) + '.' + branch
default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
return os.getenv('MAAS_CACHE',
os.path.join(default_cache_dir, 'hub', model_id_expanded))

+ 6
- 0
maas_lib/utils/registry.py View File

@@ -100,6 +100,12 @@ class Registry(object):
>>> class SwinTransformerDefaultGroup:
>>> pass

>>> class SwinTransformer2:
>>> pass
>>> MODELS.register_module('image-classification',
module_name='SwinT2',
module_cls=SwinTransformer2)

Args:
group_key: Group name of which module will be registered,
default group name is 'default'


+ 2
- 0
tests/pipelines/test_base.py View File

@@ -8,6 +8,7 @@ import PIL

from maas_lib.pipelines import Pipeline, pipeline
from maas_lib.pipelines.builder import PIPELINES
from maas_lib.pipelines.default import add_default_pipeline_info
from maas_lib.utils.constant import Tasks
from maas_lib.utils.logger import get_logger
from maas_lib.utils.registry import default_group
@@ -75,6 +76,7 @@ class CustomPipelineTest(unittest.TestCase):
return inputs

self.assertTrue('custom-image' in PIPELINES.modules[default_group])
add_default_pipeline_info(Tasks.image_tagging, 'custom-image')
pipe = pipeline(pipeline_name='custom-image')
pipe2 = pipeline(Tasks.image_tagging)
self.assertTrue(type(pipe) is type(pipe2))


+ 1
- 0
tests/pipelines/test_image_captioning.py View File

@@ -11,6 +11,7 @@ from maas_lib.utils.constant import Tasks

class ImageCaptionTest(unittest.TestCase):

@unittest.skip('skip long test')
def test_run(self):
model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt'



+ 12
- 2
tests/pipelines/test_image_matting.py View File

@@ -7,9 +7,10 @@ import unittest
import cv2

from maas_lib.fileio import File
from maas_lib.pipelines import pipeline, util
from maas_lib.pipelines import pipeline
from maas_lib.pydatasets import PyDataset
from maas_lib.utils.constant import Tasks
from maas_lib.utils.hub import get_model_cache_dir


class ImageMattingTest(unittest.TestCase):
@@ -20,7 +21,7 @@ class ImageMattingTest(unittest.TestCase):
purge_cache = True
if purge_cache:
shutil.rmtree(
util.get_model_cache_dir(self.model_id), ignore_errors=True)
get_model_cache_dir(self.model_id), ignore_errors=True)

def test_run(self):
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
@@ -59,6 +60,15 @@ class ImageMattingTest(unittest.TestCase):
cv2.imwrite('result.png', result['output_png'])
print(f'Output written to {osp.abspath("result.png")}')

def test_run_modelhub_default_model(self):
img_matting = pipeline(Tasks.image_matting)

result = img_matting(
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
)
cv2.imwrite('result.png', result['output_png'])
print(f'Output written to {osp.abspath("result.png")}')


if __name__ == '__main__':
unittest.main()

+ 9
- 2
tests/pipelines/test_text_classification.py View File

@@ -7,10 +7,11 @@ from pathlib import Path
from maas_lib.fileio import File
from maas_lib.models import Model
from maas_lib.models.nlp import BertForSequenceClassification
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
from maas_lib.preprocessors import SequenceClassificationPreprocessor
from maas_lib.pydatasets import PyDataset
from maas_lib.utils.constant import Tasks
from maas_lib.utils.hub import get_model_cache_dir


class SequenceClassificationTest(unittest.TestCase):
@@ -21,7 +22,7 @@ class SequenceClassificationTest(unittest.TestCase):
purge_cache = True
if purge_cache:
shutil.rmtree(
util.get_model_cache_dir(self.model_id), ignore_errors=True)
get_model_cache_dir(self.model_id), ignore_errors=True)

def predict(self, pipeline_ins: SequenceClassificationPipeline):
from easynlp.appzoo import load_dataset
@@ -83,6 +84,12 @@ class SequenceClassificationTest(unittest.TestCase):
PyDataset.load('glue', name='sst2', target='sentence'))
self.printDataset(result)

def test_run_with_default_model(self):
text_classification = pipeline(task=Tasks.text_classification)
result = text_classification(
PyDataset.load('glue', name='sst2', target='sentence'))
self.printDataset(result)

def test_run_with_dataset(self):
model = Model.from_pretrained(self.model_id)
preprocessor = SequenceClassificationPreprocessor(


+ 5
- 0
tests/pipelines/test_text_generation.py View File

@@ -15,6 +15,7 @@ class TextGenerationTest(unittest.TestCase):
input1 = "今日天气类型='晴'&温度变化趋势='大幅上升'&最低气温='28℃'&最高气温='31℃'&体感='湿热'"
input2 = "今日天气类型='多云'&体感='舒适'&最低气温='26℃'&最高气温='30℃'"

@unittest.skip('skip temporarily to save test time')
def test_run(self):
cache_path = snapshot_download(self.model_id)
preprocessor = TextGenerationPreprocessor(
@@ -41,6 +42,10 @@ class TextGenerationTest(unittest.TestCase):
task=Tasks.text_generation, model=self.model_id)
print(pipeline_ins(self.input2))

def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.text_generation)
print(pipeline_ins(self.input2))


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save