
[to #42322933] add image_caption_pipeline with OFA

1. add OFA whl for image caption pipeline
2. fix a bug in pipelines/builder.py
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8930942

* [to #41669377] docs and tools refinement and release

1. add build_doc linter script
2. add sphinx-docs support
3. add development doc and api doc
4. change version to 0.1.0 for the first internal release version

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307

* [to #41669377] add pipeline tutorial and fix bugs 

1. add pipeline tutorial
2. fix bugs when using pipeline with certain models and preprocessors

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8814301

* refine doc

* refine doc

* upload ofa for caption (with source code, but not whl)

* remove data in gitignore

* append uncommitted data dir in ofa

* remove ofa_dir, use ofa.whl instead.

* update BPE

* rollback changes used in debugging.

* Merge branch 'master' into ofa/image_caption

# Conflicts:
#	docs/README.md
#	docs/source/conf.py
#	docs/source/index.rst
#	docs/source/tutorials/pipeline.md
#	maas_lib/models/nlp/sequence_classification_model.py
#	maas_lib/pipelines/builder.py
#	maas_lib/version.py
#	setup.py
#	tests/pipelines/test_text_classification.py

* 1. fix a bug in pipelines/builder.py.
2. rename model_path to model in image_captioning.py.

* 1. rename test_image_captioning.py.

* format all files using pre-commit.

* add fairseq in requirements.txt

* add fairseq in requirements.txt

* change fairseq path from git repo to a whl on OSS in ofa.txt.

* change module_name to 'ofa'

* Merge remote-tracking branch 'origin/master' into ofa/image_caption

# Conflicts:
#	maas_lib/pipelines/builder.py

* optimize requirements for ofa / refine image_captioning.py

* uncommitted change.

* feat: Fix conflict, auto commit by WebIDE
menrui.mr huangjun.hj · 3 years ago · commit 68e64a90d4 (master)
5 changed files with 164 additions and 0 deletions:

  1. maas_lib/pipelines/multi_modal/__init__.py (+1, -0)
  2. maas_lib/pipelines/multi_modal/image_captioning.py (+118, -0)
  3. requirements.txt (+1, -0)
  4. requirements/multi-modal.txt (+9, -0)
  5. tests/pipelines/test_image_captioning.py (+35, -0)

maas_lib/pipelines/multi_modal/__init__.py (+1, -0)

@@ -0,0 +1 @@
from .image_captioning import ImageCaptionPipeline

maas_lib/pipelines/multi_modal/image_captioning.py (+118, -0)

@@ -0,0 +1,118 @@
from typing import Any, Dict

import numpy as np
import torch
from fairseq import checkpoint_utils, tasks, utils
from ofa.models.ofa import OFAModel  # imported for its side effect of registering the OFA architecture with fairseq
from ofa.tasks.mm_tasks import CaptionTask
from ofa.utils.eval_utils import eval_caption
from PIL import Image

from maas_lib.pipelines.base import Input
from maas_lib.preprocessors import load_image
from maas_lib.utils.constant import Tasks
from maas_lib.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
class ImageCaptionPipeline(Pipeline):
    # TODO: refine using modelhub

    def __init__(self, model: str, bpe_dir: str):
        super().__init__()

        # Register the caption task so fairseq can resolve it when
        # loading the checkpoint.
        tasks.register_task('caption', CaptionTask)
        # turn on cuda if GPU is available
        use_cuda = False
        # use fp16 only when GPU is available
        use_fp16 = False
        overrides = {
            'bpe_dir': bpe_dir,
            'eval_cider': False,
            'beam': 5,
            'max_len_b': 16,
            'no_repeat_ngram_size': 3,
            'seed': 7
        }
        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(model), arg_overrides=overrides)

        # Move models to GPU and prepare them for inference
        for model in models:
            model.eval()
            if use_cuda:
                model.cuda()
            if use_fp16:
                model.half()
            model.prepare_for_inference_(cfg)
        self.models = models

        # Initialize generator
        self.generator = task.build_generator(models, cfg.generation)

        # Initialize transform
        from torchvision import transforms
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]

        self.patch_resize_transform = transforms.Compose([
            lambda image: image.convert('RGB'),
            transforms.Resize(
                (cfg.task.patch_image_size, cfg.task.patch_image_size),
                interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        self.task = task
        self.bos_item = torch.LongTensor([task.src_dict.bos()])
        self.eos_item = torch.LongTensor([task.src_dict.eos()])
        self.pad_idx = task.src_dict.pad()

    def preprocess(self, input: Input) -> Dict[str, Any]:

        def encode_text(text, length=None, append_bos=False, append_eos=False):
            s = self.task.tgt_dict.encode_line(
                line=self.task.bpe.encode(text),
                add_if_not_exist=False,
                append_eos=False).long()
            if length is not None:
                s = s[:length]
            if append_bos:
                s = torch.cat([self.bos_item, s])
            if append_eos:
                s = torch.cat([s, self.eos_item])
            return s

        # Build a fairseq-style sample: the resized image patch plus the
        # fixed caption prompt encoded with BPE.
        patch_image = self.patch_resize_transform(
            load_image(input)).unsqueeze(0)
        patch_mask = torch.tensor([True])
        text = 'what does the image describe?'
        src_text = encode_text(
            text, append_bos=True, append_eos=True).unsqueeze(0)
        src_length = torch.LongTensor(
            [s.ne(self.pad_idx).long().sum() for s in src_text])
        sample = {
            'id': np.array(['42']),
            'net_input': {
                'src_tokens': src_text,
                'src_lengths': src_length,
                'patch_images': patch_image,
                'patch_masks': patch_mask,
            }
        }
        return sample

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        results, _ = eval_caption(self.task, self.generator, self.models,
                                  input)
        return {
            'image_id': results[0]['image_id'],
            'caption': results[0]['caption']
        }

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # What should we do here?
        return inputs
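
For reference: the register_module decorator above binds this class to Tasks.image_captioning under module name 'ofa', which is what lets the generic pipeline builder construct it by task name. A minimal usage sketch, assuming a locally downloaded checkpoint and unpacked BPE directory (both paths below are placeholders; the test file further down uses the real URLs):

from maas_lib.pipelines import pipeline
from maas_lib.utils.constant import Tasks

# Placeholder paths: a downloaded OFA caption checkpoint and an
# unpacked BPE directory.
img_captioning = pipeline(
    Tasks.image_captioning,
    model='/path/to/caption_large_best_clean.pt',
    bpe_dir='/path/to/BPE')
result = img_captioning('/path/to/image.png')
print(result['caption'])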

requirements.txt (+1, -0)

@@ -1,3 +1,4 @@
-r requirements/runtime.txt
-r requirements/pipeline.txt
-r requirements/multi-modal.txt
-r requirements/nlp.txt

requirements/multi-modal.txt (+9, -0)

@@ -0,0 +1,9 @@
datasets
einops
ftfy>=6.0.3
https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/fairseq-maas-py3-none-any.whl
https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/ofa-0.0.2-py3-none-any.whl
pycocoevalcap>=1.2
pycocotools>=2.0.4
rouge_score
timm

tests/pipelines/test_image_captioning.py (+35, -0)

@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import tempfile
import unittest

from maas_lib.fileio import File
from maas_lib.pipelines import pipeline
from maas_lib.utils.constant import Tasks


class ImageCaptionTest(unittest.TestCase):

    def test_run(self):
        model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt'

        # Fetch the BPE vocabulary used by the OFA tokenizer.
        os.system(
            'wget https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/BPE.zip'
        )
        os.system('unzip BPE.zip')
        bpe_dir = './BPE'

        # Download the checkpoint to a local file, then build the
        # pipeline by task name.
        with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
            ofile.write(File.read(model))
            img_captioning = pipeline(
                Tasks.image_captioning, model=ofile.name, bpe_dir=bpe_dir)

            result = img_captioning(
                'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
            )
            print(result['caption'])


if __name__ == '__main__':
    unittest.main()
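
The test goes through the builder; since ImageCaptionPipeline is exported from maas_lib.pipelines.multi_modal, it can also be instantiated directly when the registry is not needed. A minimal sketch with the same placeholder paths as above:

from maas_lib.pipelines.multi_modal import ImageCaptionPipeline

# Placeholder paths, as above.
pipe = ImageCaptionPipeline(
    model='/path/to/caption_large_best_clean.pt',
    bpe_dir='/path/to/BPE')
print(pipe('/path/to/image.png')['caption'])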
