diff --git a/maas_lib/pipelines/multi_modal/__init__.py b/maas_lib/pipelines/multi_modal/__init__.py
index e69de29b..7d9a2c59 100644
--- a/maas_lib/pipelines/multi_modal/__init__.py
+++ b/maas_lib/pipelines/multi_modal/__init__.py
@@ -0,0 +1 @@
+from .image_captioning import ImageCaptionPipeline
diff --git a/maas_lib/pipelines/multi_modal/image_captioning.py b/maas_lib/pipelines/multi_modal/image_captioning.py
new file mode 100644
index 00000000..778354b7
--- /dev/null
+++ b/maas_lib/pipelines/multi_modal/image_captioning.py
@@ -0,0 +1,118 @@
+from typing import Any, Dict
+
+import numpy as np
+import torch
+from fairseq import checkpoint_utils, tasks, utils
+from ofa.models.ofa import OFAModel
+from ofa.tasks.mm_tasks import CaptionTask
+from ofa.utils.eval_utils import eval_caption
+from PIL import Image
+
+from maas_lib.pipelines.base import Input
+from maas_lib.preprocessors import load_image
+from maas_lib.utils.constant import Tasks
+from maas_lib.utils.logger import get_logger
+from ..base import Pipeline
+from ..builder import PIPELINES
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(Tasks.image_captioning, module_name='ofa')
+class ImageCaptionPipeline(Pipeline):
+    # TODO: refine using modelhub
+    def __init__(self, model: str, bpe_dir: str):
+        super().__init__()
+
+        tasks.register_task('caption', CaptionTask)
+        # TODO: turn on cuda if a GPU is available
+        use_cuda = False
+        # TODO: use fp16 only when a GPU is available
+        use_fp16 = False
+        overrides = {
+            'bpe_dir': bpe_dir,
+            'eval_cider': False,
+            'beam': 5,
+            'max_len_b': 16,
+            'no_repeat_ngram_size': 3,
+            'seed': 7
+        }
+        models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            utils.split_paths(model), arg_overrides=overrides)
+
+        # Prepare each model for inference (eval mode; CUDA/fp16 if enabled)
+        for model in models:
+            model.eval()
+            if use_cuda:
+                model.cuda()
+            if use_fp16:
+                model.half()
+            model.prepare_for_inference_(cfg)
+        self.models = models
+        # Initialize generator
+        self.generator = task.build_generator(models, cfg.generation)
+
+        # Initialize transform
+        from torchvision import transforms
+        mean = [0.5, 0.5, 0.5]
+        std = [0.5, 0.5, 0.5]
+
+        self.patch_resize_transform = transforms.Compose([
+            lambda image: image.convert('RGB'),
+            transforms.Resize(
+                (cfg.task.patch_image_size, cfg.task.patch_image_size),
+                interpolation=Image.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+        self.task = task
+        self.bos_item = torch.LongTensor([task.src_dict.bos()])
+        self.eos_item = torch.LongTensor([task.src_dict.eos()])
+        self.pad_idx = task.src_dict.pad()
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+
+        def encode_text(text, length=None, append_bos=False, append_eos=False):
+            s = self.task.tgt_dict.encode_line(
+                line=self.task.bpe.encode(text),
+                add_if_not_exist=False,
+                append_eos=False).long()
+            if length is not None:
+                s = s[:length]
+            if append_bos:
+                s = torch.cat([self.bos_item, s])
+            if append_eos:
+                s = torch.cat([s, self.eos_item])
+            return s
+
+        patch_image = self.patch_resize_transform(
+            load_image(input)).unsqueeze(0)
+        patch_mask = torch.tensor([True])
+        text = 'what does the image describe?'
+        src_text = encode_text(
+            text, append_bos=True, append_eos=True).unsqueeze(0)
+        src_length = torch.LongTensor(
+            [s.ne(self.pad_idx).long().sum() for s in src_text])
+        sample = {
+            'id': np.array(['42']),  # dummy sample id; returned as image_id
+            'net_input': {
+                'src_tokens': src_text,
+                'src_lengths': src_length,
+                'patch_images': patch_image,
+                'patch_masks': patch_mask,
+            }
+        }
+        return sample
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        results, _ = eval_caption(self.task, self.generator, self.models,
+                                  input)
+        return {
+            'image_id': results[0]['image_id'],
+            'caption': results[0]['caption']
+        }
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        # TODO: decide what post-processing is needed here
+        return inputs
diff --git a/requirements.txt b/requirements.txt
index 3cc6857e..1944b476 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 -r requirements/runtime.txt
 -r requirements/pipeline.txt
+-r requirements/multi-modal.txt
 -r requirements/nlp.txt
diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt
new file mode 100644
index 00000000..ad641b63
--- /dev/null
+++ b/requirements/multi-modal.txt
@@ -0,0 +1,9 @@
+datasets
+einops
+ftfy>=6.0.3
+https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/fairseq-maas-py3-none-any.whl
+https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/ofa-0.0.2-py3-none-any.whl
+pycocoevalcap>=1.2
+pycocotools>=2.0.4
+rouge_score
+timm
diff --git a/tests/pipelines/test_image_captioning.py b/tests/pipelines/test_image_captioning.py
new file mode 100644
index 00000000..f951f0a8
--- /dev/null
+++ b/tests/pipelines/test_image_captioning.py
@@ -0,0 +1,35 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import tempfile
+import unittest
+
+from maas_lib.fileio import File
+from maas_lib.pipelines import pipeline
+from maas_lib.utils.constant import Tasks
+
+
+class ImageCaptionTest(unittest.TestCase):
+
+    def test_run(self):
+        model = 'https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt'
+
+        os.system(
+            'wget https://jirenmr.oss-cn-zhangjiakou.aliyuncs.com/ofa/BPE.zip'
+        )
+        os.system('unzip BPE.zip')
+        bpe_dir = './BPE'
+
+        with tempfile.NamedTemporaryFile('wb', suffix='.pb') as ofile:
+            ofile.write(File.read(model))
+            img_captioning = pipeline(
+                Tasks.image_captioning, model=ofile.name, bpe_dir=bpe_dir)
+
+            result = img_captioning(
+                'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
+            )
+            print(result['caption'])
+
+
+if __name__ == '__main__':
+    unittest.main()