
[to #42322933] update ofa caption model

Move the caption pipeline implementation down from the pipeline layer into the model, and split out the preprocessor.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9081211
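
For reference, the user-facing entry point is unchanged by this refactor; the snippet below mirrors the updated unit test at the bottom of this diff (model id taken from the new default in builder.py):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    img_captioning = pipeline(
        Tasks.image_captioning, model='damo/ofa_image-caption_coco_large_en')
    result = img_captioning('data/test/images/image_captioning.png')
    print(result['caption'])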

* [to #41669377] docs and tools refinement and release

1. add build_doc linter script
2. add sphinx-docs support
3. add development doc and api doc
4. change version to 0.1.0 for the first internal release

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307

* [to #41669377] add pipeline tutorial and fix bugs 

1. add pipeline tutorial
2. fix bugs when using the pipeline with certain models and preprocessors

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8814301

* refine doc

* refine doc

* upload ofa for caption (with source code but not whl)

* remove data in gitignore

* append uncommitted data dir in ofa

* remove ofa_dir, use ofa.whl instead.

* update BPE

* rollback changes used in debugging.

* Merge branch 'master' into ofa/image_caption

# Conflicts:
#	docs/README.md
#	docs/source/conf.py
#	docs/source/index.rst
#	docs/source/tutorials/pipeline.md
#	maas_lib/models/nlp/sequence_classification_model.py
#	maas_lib/pipelines/builder.py
#	maas_lib/version.py
#	setup.py
#	tests/pipelines/test_text_classification.py

* 1. fix a bug in pipelines/builder.py.
2. rename model_path to model in image_captioning.py.

* 1. rename test_image_captioning.py.

* format all files using pre-commit.

* add fairseq in requirements.txt

* add fairseq in requirements.txt

* change fairseq path in ofa.txt from a git repo to a whl on OSS.

* change module_name to 'ofa'

* Merge remote-tracking branch 'origin/master' into ofa/image_caption

# Conflicts:
#	maas_lib/pipelines/builder.py

* optimize requirements for ofa / refine image_captioning.py

* uncommitted change.

* feat: Fix conflict, auto commit by WebIDE

* Merge remote-tracking branch 'origin/master' into ofa/image_caption

# Conflicts:
#	maas_lib/pipelines/multi_modal/__init__.py
#	modelscope/pipelines/multi_modal/image_captioning.py
#	tests/pipelines/test_image_captioning.py

* merge master

* merge master

* merge master

* rename

* Merge remote-tracking branch 'origin/master' into ofa/nlu

* add caption model

* Merge remote-tracking branch 'origin/master' into ofa/nlu

* update ofa caption model

* fix some typos, update unittest

* use local test image

* use local test image

* refactor, ofa -> multi_model

* merge master

* Delete image_caption_pipeline.py
master · yichang.zyc, huangjun.hj · 3 years ago · commit a2cf3d619e
10 changed files with 168 additions and 69 deletions:

1. data/test/images/image_captioning.png (+3 / -0)
2. modelscope/models/__init__.py (+1 / -0)
3. modelscope/models/multi_model/__init__.py (+1 / -0)
4. modelscope/models/multi_model/image_captioning_model.py (+80 / -0)
5. modelscope/pipelines/builder.py (+1 / -1)
6. modelscope/pipelines/multi_modal/__init__.py (+1 / -1)
7. modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+33 / -0)
8. modelscope/preprocessors/__init__.py (+1 / -0)
9. modelscope/preprocessors/multi_model.py (+41 / -48)
10. tests/pipelines/test_image_captioning.py (+6 / -19)

data/test/images/image_captioning.png (+3 / -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:af83a94899a6d23339c3ecc5c4c58c57c835af57b531a2f4c50461184f820141
size 603621

modelscope/models/__init__.py (+1 / -0)

@@ -4,4 +4,5 @@ from .audio.tts.am import SambertNetHifi16k
 from .audio.tts.vocoder import Hifigan16k
 from .base import Model
 from .builder import MODELS, build_model
+from .multi_model import OfaForImageCaptioning
 from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity

modelscope/models/multi_model/__init__.py (+1 / -0)

@@ -0,0 +1 @@
+from .image_captioning_model import OfaForImageCaptioning

modelscope/models/multi_model/image_captioning_model.py (+80 / -0)

@@ -0,0 +1,80 @@
import os.path as osp
from typing import Any, Dict

from PIL import Image

from modelscope.utils.constant import ModelFile, Tasks
from ..base import Model
from ..builder import MODELS

__all__ = ['OfaForImageCaptioning']


@MODELS.register_module(
    Tasks.image_captioning, module_name=r'ofa-image-captioning')
class OfaForImageCaptioning(Model):

    def __init__(self, model_dir, *args, **kwargs):
        super().__init__(model_dir=model_dir, *args, **kwargs)
        ckpt_name = ModelFile.TORCH_MODEL_FILE
        local_model = osp.join(model_dir, ckpt_name)
        bpe_dir = model_dir
        # turn on cuda if GPU is available
        from fairseq import checkpoint_utils, tasks, utils
        from ofa.tasks.mm_tasks import CaptionTask
        from ofa.utils.eval_utils import eval_caption
        self.eval_caption = eval_caption

        tasks.register_task('caption', CaptionTask)
        use_cuda = kwargs['use_cuda'] if 'use_cuda' in kwargs else False
        use_fp16 = kwargs[
            'use_fp16'] if 'use_fp16' in kwargs and use_cuda else False
        overrides = {
            'bpe_dir': bpe_dir,
            'eval_cider': False,
            'beam': 5,
            'max_len_b': 16,
            'no_repeat_ngram_size': 3,
            'seed': 7
        }
        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(local_model), arg_overrides=overrides)

        # Move models to GPU
        for model in models:
            model.eval()
            if use_cuda:
                model.cuda()
            if use_fp16:
                model.half()
            model.prepare_for_inference_(cfg)
        self.models = models
        # Initialize generator
        self.generator = task.build_generator(models, cfg.generation)

        # Initialize transform
        from torchvision import transforms
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),
            transforms.Resize(
                (cfg.task.patch_image_size, cfg.task.patch_image_size),
                interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        self.task = task

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        results, _ = self.eval_caption(self.task, self.generator, self.models,
                                       input)
        return {
            'image_id': results[0]['image_id'],
            'caption': results[0]['caption']
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # What should we do here ?
        return inputs
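
With inference moved into the model, the model consumes the sample dict built by the preprocessor (split out into modelscope/preprocessors/multi_model.py below). A minimal sketch of the hand-off, assuming a local model directory (the path here is hypothetical):

    from modelscope.models import OfaForImageCaptioning
    from modelscope.preprocessors import OfaImageCaptionPreprocessor

    model_dir = './ofa_image-caption_coco_large_en'  # hypothetical local checkout
    model = OfaForImageCaptioning(model_dir)
    preprocessor = OfaImageCaptionPreprocessor(model_dir=model_dir)

    sample = preprocessor('data/test/images/image_captioning.png')
    result = model.forward(sample)  # runs eval_caption under the hood
    print(result['caption'])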

modelscope/pipelines/builder.py (+1 / -1)

@@ -24,7 +24,7 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     ('bert-sentiment-analysis', 'damo/bert-base-sst2'),
     Tasks.text_generation: ('palm2.0',
                             'damo/nlp_palm2.0_text-generation_chinese-base'),
-    Tasks.image_captioning: ('ofa', None),
+    Tasks.image_captioning: ('ofa', 'damo/ofa_image-caption_coco_large_en'),
     Tasks.image_generation:
     ('person-image-cartoon',
      'damo/cv_unet_person-image-cartoon_compound-models'),
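
Since the default model id is now filled in, the task name alone should suffice (a sketch, assuming the builder falls back to DEFAULT_MODEL_FOR_PIPELINE when no model is passed):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Resolves to ('ofa', 'damo/ofa_image-caption_coco_large_en') by default.
    img_captioning = pipeline(Tasks.image_captioning)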


modelscope/pipelines/multi_modal/__init__.py (+1 / -1)

@@ -1 +1 @@
-from .image_caption_pipeline import ImageCaptionPipeline
+from .image_captioning_pipeline import ImageCaptionPipeline

modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+33 / -0)

@@ -0,0 +1,33 @@
from typing import Any, Dict, Optional, Union

from modelscope.preprocessors import OfaImageCaptionPreprocessor, Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Model, Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
class ImageCaptionPipeline(Pipeline):

    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Optional[Preprocessor] = None,
                 **kwargs):
        assert isinstance(model, str) or isinstance(model, Model), \
            'model must be a single str or OfaForImageCaptioning'
        if isinstance(model, str):
            pipe_model = Model.from_pretrained(model)
        elif isinstance(model, Model):
            pipe_model = model
        else:
            raise NotImplementedError
        if preprocessor is None and pipe_model:
            preprocessor = OfaImageCaptionPreprocessor(model_dir=model)
        super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
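
Both constructor entry points, sketched (direct construction of the pipeline class is an assumption here; normally the builder does this):

    from modelscope.models import Model
    from modelscope.pipelines.multi_modal import ImageCaptionPipeline
    from modelscope.preprocessors import OfaImageCaptionPreprocessor

    model_id = 'damo/ofa_image-caption_coco_large_en'

    # From a model id or path: the pipeline loads the model and builds the
    # default OfaImageCaptionPreprocessor itself.
    pipe = ImageCaptionPipeline(model_id)

    # From an already-constructed Model, with an explicit preprocessor.
    model = Model.from_pretrained(model_id)
    preprocessor = OfaImageCaptionPreprocessor(model_dir=model_id)
    pipe = ImageCaptionPipeline(model, preprocessor=preprocessor)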

modelscope/preprocessors/__init__.py (+1 / -0)

@@ -5,5 +5,6 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
+from .multi_model import OfaImageCaptionPreprocessor
 from .nlp import *  # noqa F403
 from .text_to_speech import *  # noqa F403

modelscope/pipelines/multi_modal/image_caption_pipeline.py → modelscope/preprocessors/multi_model.py (+41 / -48)

@@ -1,32 +1,50 @@
-from typing import Any, Dict
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+from typing import Any, Dict, Union
 
 import numpy as np
 import torch
+from maas_hub.snapshot_download import snapshot_download
 from PIL import Image
 
-from modelscope.pipelines.base import Input
-from modelscope.preprocessors import load_image
-from modelscope.utils.constant import Tasks
-from modelscope.utils.logger import get_logger
-from ..base import Pipeline
-from ..builder import PIPELINES
+from modelscope.utils.constant import Fields, ModelFile
+from modelscope.utils.hub import get_model_cache_dir
+from modelscope.utils.type_assert import type_assert
+from .base import Preprocessor
+from .builder import PREPROCESSORS
+from .image import load_image
 
-logger = get_logger()
+__all__ = [
+    'OfaImageCaptionPreprocessor',
+]
 
 
-@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
-class ImageCaptionPipeline(Pipeline):
-    # TODO: refine using modelhub
-    def __init__(self, model: str, bpe_dir: str):
-        super().__init__()
-        # turn on cuda if GPU is available
+@PREPROCESSORS.register_module(
+    Fields.multi_modal, module_name=r'ofa-image-caption')
+class OfaImageCaptionPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the vocab.txt from the `model_dir` path
+
+        Args:
+            model_dir (str): model path
+        """
+        super().__init__(*args, **kwargs)
+
+        if osp.exists(model_dir):
+            local_model_dir = model_dir
+        else:
+            cache_path = get_model_cache_dir(model_dir)
+            local_model_dir = cache_path if osp.exists(
+                cache_path) else snapshot_download(model_dir)
+        local_model = osp.join(local_model_dir, ModelFile.TORCH_MODEL_FILE)
+        bpe_dir = local_model_dir
+
         from fairseq import checkpoint_utils, tasks, utils
         from ofa.tasks.mm_tasks import CaptionTask
 
         tasks.register_task('caption', CaptionTask)
-        use_cuda = False
-        # use fp16 only when GPU is available
-        use_fp16 = False
-
         overrides = {
             'bpe_dir': bpe_dir,
             'eval_cider': False,
@@ -35,21 +53,9 @@ class ImageCaptionPipeline(Pipeline):
             'no_repeat_ngram_size': 3,
             'seed': 7
         }
-        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-            utils.split_paths(model), arg_overrides=overrides)
-
-        # Move models to GPU
-        for model in models:
-            model.eval()
-            if use_cuda:
-                model.cuda()
-            if use_fp16:
-                model.half()
-            model.prepare_for_inference_(cfg)
-        self.models = models
-        # Initialize generator
-        self.generator = task.build_generator(models, cfg.generation)
-
+        model, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            utils.split_paths(local_model), arg_overrides=overrides)
+        del model
         # Initialize transform
         from torchvision import transforms
         mean = [0.5, 0.5, 0.5]
@@ -69,7 +75,8 @@ class ImageCaptionPipeline(Pipeline):
         self.eos_item = torch.LongTensor([task.src_dict.eos()])
         self.pad_idx = task.src_dict.pad()
 
-    def preprocess(self, input: Input) -> Dict[str, Any]:
+    @type_assert(object, (str, tuple))
+    def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:
 
         def encode_text(text, length=None, append_bos=False, append_eos=False):
             s = self.task.tgt_dict.encode_line(
@@ -88,7 +95,7 @@ class ImageCaptionPipeline(Pipeline):
             patch_image = self.patch_resize_transform(input).unsqueeze(0)
         else:
             patch_image = self.patch_resize_transform(
-                load_image(input)).unsqueeze(0)
+                load_image(data)).unsqueeze(0)
         patch_mask = torch.tensor([True])
         text = 'what does the image describe?'
         src_text = encode_text(
@@ -105,17 +112,3 @@ class ImageCaptionPipeline(Pipeline):
             }
         }
         return sample
-
-    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
-        from ofa.utils.eval_utils import eval_caption
-
-        results, _ = eval_caption(self.task, self.generator, self.models,
-                                  input)
-        return {
-            'image_id': results[0]['image_id'],
-            'caption': results[0]['caption']
-        }
-
-    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        # What should we do here ?
-        return inputs
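
The renamed preprocessor is now usable standalone; a minimal sketch (hub id from builder.py above; note the checkpoint is loaded here only to recover the task config and BPE, then discarded via `del model`):

    from modelscope.preprocessors import OfaImageCaptionPreprocessor

    # Accepts a hub model id or a local directory; non-local ids resolve
    # through get_model_cache_dir / snapshot_download as shown above.
    preprocessor = OfaImageCaptionPreprocessor(
        'damo/ofa_image-caption_coco_large_en')

    # __call__ is type-asserted to (str, tuple); a path string is read
    # via load_image before the patch-resize transform is applied.
    sample = preprocessor('data/test/images/image_captioning.png')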

tests/pipelines/test_image_captioning.py (+6 / -19)

@@ -1,10 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-import os
-import tempfile
 import unittest
 
-from modelscope.fileio import File
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 from modelscope.utils.test_utils import test_level
@@ -12,23 +9,13 @@ from modelscope.utils.test_utils import test_level
 
 class ImageCaptionTest(unittest.TestCase):
 
-    @unittest.skip('skip before model is restored in model hub')
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run(self):
-        model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt'
-
-        os.system(
-            'wget https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/BPE.zip'
-        )
-        os.system('unzip BPE.zip')
-        bpe_dir = './BPE'
-
-        with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
-            ofile.write(File.read(model))
-            img_captioning = pipeline(
-                Tasks.image_captioning, model=ofile.name, bpe_dir=bpe_dir)
-
-            result = img_captioning('data/test/images/image_matting.png')
-            print(result['caption'])
+        img_captioning = pipeline(
+            Tasks.image_captioning,
+            model='damo/ofa_image-caption_coco_large_en')
+        result = img_captioning('data/test/images/image_captioning.png')
+        print(result['caption'])
 
 
 if __name__ == '__main__':
