diff --git a/data/test/images/image_mplug_vqa.jpg b/data/test/images/image_mplug_vqa.jpg
new file mode 100644
index 00000000..57919471
--- /dev/null
+++ b/data/test/images/image_mplug_vqa.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b37b706885849037b5fa7fa44a3b78a6375f768d95ce46bfcb8e7329d038a692
+size 181725
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 1d2ee4d2..485605bb 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -27,6 +27,7 @@ class Models(object):
     # multi-modal models
     ofa = 'ofa'
     clip = 'clip-multi-modal-embedding'
+    mplug = 'mplug'
 
 
 class Pipelines(object):
@@ -63,6 +64,7 @@ class Pipelines(object):
     # multi-modal tasks
     image_caption = 'image-caption'
     multi_modal_embedding = 'multi-modal-embedding'
+    visual_question_answering = 'visual-question-answering'
 
 
 class Trainers(object):
@@ -105,3 +107,4 @@ class Preprocessors(object):
 
     # multi-modal
     ofa_image_caption = 'ofa-image-caption'
+    mplug_visual_question_answering = 'mplug-visual-question-answering'
diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py
index 2e6cc3bf..4ed9809b 100644
--- a/modelscope/models/multi_modal/__init__.py
+++ b/modelscope/models/multi_modal/__init__.py
@@ -1,2 +1,4 @@
 from .clip.clip_model import CLIPForMultiModalEmbedding
 from .image_captioning_model import OfaForImageCaptioning
+from .mplug_for_visual_question_answering import \
+    MPlugForVisualQuestionAnswering
diff --git a/modelscope/models/multi_modal/mplug_for_visual_question_answering.py b/modelscope/models/multi_modal/mplug_for_visual_question_answering.py
new file mode 100644
index 00000000..2682c048
--- /dev/null
+++ b/modelscope/models/multi_modal/mplug_for_visual_question_answering.py
@@ -0,0 +1,46 @@
+from typing import Dict
+
+from ...metainfo import Models
+from ...utils.constant import Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+
+__all__ = ['MPlugForVisualQuestionAnswering']
+
+
+@MODELS.register_module(
+    Tasks.visual_question_answering, module_name=Models.mplug)
+class MPlugForVisualQuestionAnswering(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """initialize the mplug model from the `model_dir` path.
+        Args:
+            model_dir (str): the model path.
+ """ + + super().__init__(model_dir, *args, **kwargs) + from sofa.models.mplug import MPlugForVisualQuestionAnswering + self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir) + self.tokenizer = self.model.tokenizer + + def train(self): + return self.model.train() + + def eval(self): + return self.model.eval() + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'predictions': Tensor([[1377, 4959, 2785, 6392...])]), + } + """ + + return self.model(**input)[0] diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 847955d4..2f66682d 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -42,7 +42,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_TAdaConv_action-recognition'), Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding, - 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding') + 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'), + Tasks.visual_question_answering: + (Pipelines.visual_question_answering, + 'damo/mplug_visual-question-answering_coco_large_en'), } diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index 6c96d843..fdcada89 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -1,2 +1,3 @@ from .image_captioning_pipeline import ImageCaptionPipeline from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline +from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline diff --git a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py new file mode 100644 index 00000000..97c8cf7b --- /dev/null +++ b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, Optional, Union + +import torch + +from ...metainfo import Pipelines +from ...models import Model +from ...models.multi_modal import MPlugForVisualQuestionAnswering +from ...preprocessors import MPlugVisualQuestionAnsweringPreprocessor +from ...utils.constant import Tasks +from ..base import Pipeline, Tensor +from ..builder import PIPELINES + +__all__ = ['VisualQuestionAnsweringPipeline'] + + +@PIPELINES.register_module( + Tasks.visual_question_answering, + module_name=Pipelines.visual_question_answering) +class VisualQuestionAnsweringPipeline(Pipeline): + + def __init__(self, + model: Union[MPlugForVisualQuestionAnswering, str], + preprocessor: Optional[ + MPlugVisualQuestionAnsweringPreprocessor] = None, + **kwargs): + """use `model` and `preprocessor` to create a visual question answering pipeline for prediction + + Args: + model (MPlugForVisualQuestionAnswering): a model instance + preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance + """ + model = model if isinstance( + model, + MPlugForVisualQuestionAnswering) else Model.from_pretrained(model) + if preprocessor is None: + preprocessor = MPlugVisualQuestionAnsweringPreprocessor( + model.model_dir) + model.eval() + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + self.tokenizer = model.tokenizer + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + with torch.no_grad(): + return super().forward(inputs, 
+
+    def postprocess(self, inputs: Dict[str, Tensor],
+                    **postprocess_params) -> Dict[str, str]:
+        """process the prediction results
+
+        Args:
+            inputs (Dict[str, Tensor]): the model output containing predicted token ids
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+
+        pred_string = self.tokenizer.decode(inputs[0][0])
+        for _old, _new in replace_tokens_bert:
+            pred_string = pred_string.replace(_old, _new)
+        pred_string = pred_string.strip()
+        return {'answer': pred_string}
diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py
index 1bc06ce3..694688f6 100644
--- a/modelscope/preprocessors/__init__.py
+++ b/modelscope/preprocessors/__init__.py
@@ -6,6 +6,6 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .kws import WavToLists
-from .multi_modal import OfaImageCaptionPreprocessor
+from .multi_modal import *  # noqa F403
 from .nlp import *  # noqa F403
 from .text_to_speech import *  # noqa F403
diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 7c8f0fab..1bc686eb 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -16,6 +16,7 @@ from .image import load_image
 
 __all__ = [
     'OfaImageCaptionPreprocessor',
+    'MPlugVisualQuestionAnsweringPreprocessor',
 ]
 
 
@@ -110,3 +111,47 @@ class OfaImageCaptionPreprocessor(Preprocessor):
             }
         }
         return sample
+
+
+@PREPROCESSORS.register_module(
+    Fields.multi_modal,
+    module_name=Preprocessors.mplug_visual_question_answering)
+class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via the 'bert-base-uncased' tokenizer and the model configuration
+
+        """
+        super().__init__(*args, **kwargs)
+
+        # tokenizer
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+
+        # load configuration
+        from sofa.models.mplug import CONFIG_NAME, MPlugConfig
+        config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME))
+
+        # Initialize transform
+        from torchvision import transforms
+        mean = (0.48145466, 0.4578275, 0.40821073)
+        std = (0.26862954, 0.26130258, 0.27577711)
+
+        self.patch_resize_transform = transforms.Compose([
+            transforms.Resize((config.image_res, config.image_res),
+                              interpolation=Image.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image, question = data['image'], data['question']
+        image = Image.open(image).convert('RGB') if isinstance(image,
+                                                               str) else image
+        image = self.patch_resize_transform(image)
+        image = torch.stack([image], dim=0)
+        question = self.tokenizer([question.lower()],
+                                  padding='longest',
+                                  return_tensors='pt')
+
+        return {'image': image, 'question': question, 'train': False}
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 44bd1dff..3ce3ab98 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -61,6 +61,7 @@ class Tasks(object):
     visual_grounding = 'visual-grounding'
     text_to_image_synthesis = 'text-to-image-synthesis'
     multi_modal_embedding = 'multi-modal-embedding'
+    visual_question_answering = 'visual-question-answering'
 
 
 class InputFields(object):
diff --git a/requirements/nlp.txt b/requirements/nlp.txt
index 261b9ec5..574bf856 100644
--- a/requirements/nlp.txt
+++ b/requirements/nlp.txt
@@ -1 +1 @@
-https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.4.1-py3-none-any.whl
diff --git a/tests/pipelines/test_visual_question_answering.py b/tests/pipelines/test_visual_question_answering.py
new file mode 100644
index 00000000..4577607e
--- /dev/null
+++ b/tests/pipelines/test_visual_question_answering.py
@@ -0,0 +1,60 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
+from modelscope.pipelines import VisualQuestionAnsweringPipeline, pipeline
+from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class VisualQuestionAnsweringTest(unittest.TestCase):
+    model_id = 'damo/mplug_visual-question-answering_coco_large_en'
+    input_vqa = {
+        'image': 'data/test/images/image_mplug_vqa.jpg',
+        'question': 'What is the woman doing?',
+    }
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run(self):
+        cache_path = snapshot_download(self.model_id)
+        preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path)
+        model = MPlugForVisualQuestionAnswering(cache_path)
+        pipeline1 = VisualQuestionAnsweringPipeline(
+            model, preprocessor=preprocessor)
+        pipeline2 = pipeline(
+            Tasks.visual_question_answering,
+            model=model,
+            preprocessor=preprocessor)
+        print(f"question: {self.input_vqa['question']}")
+        print(f"pipeline1: {pipeline1(self.input_vqa)['answer']}")
+        print(f"pipeline2: {pipeline2(self.input_vqa)['answer']}")
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
+            model.model_dir)
+        pipeline_vqa = pipeline(
+            task=Tasks.visual_question_answering,
+            model=model,
+            preprocessor=preprocessor)
+        print(pipeline_vqa(self.input_vqa))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_vqa = pipeline(
+            Tasks.visual_question_answering, model=self.model_id)
+        print(pipeline_vqa(self.input_vqa))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_vqa = pipeline(task=Tasks.visual_question_answering)
+        print(pipeline_vqa(self.input_vqa))
+
+
+if __name__ == '__main__':
+    unittest.main()
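For reference, a minimal usage sketch of the new pipeline (not part of the diff; it mirrors test_run_with_model_name above and the default model registered in DEFAULT_MODEL_FOR_PIPELINE):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the VQA pipeline from the model id registered for this task.
pipeline_vqa = pipeline(
    Tasks.visual_question_answering,
    model='damo/mplug_visual-question-answering_coco_large_en')

# The preprocessor expects an image (path or PIL image) and a question string;
# postprocess() returns a dict with the decoded answer under 'answer'.
result = pipeline_vqa({
    'image': 'data/test/images/image_mplug_vqa.jpg',
    'question': 'What is the woman doing?',
})
print(result['answer'])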