添加 MPLUG 模型的 visual question answering 任务 pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9182119master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:b37b706885849037b5fa7fa44a3b78a6375f768d95ce46bfcb8e7329d038a692 | |||
| size 181725 | |||
| @@ -27,6 +27,7 @@ class Models(object): | |||
| # multi-modal models | |||
| ofa = 'ofa' | |||
| clip = 'clip-multi-modal-embedding' | |||
| mplug = 'mplug' | |||
| class Pipelines(object): | |||
| @@ -63,6 +64,7 @@ class Pipelines(object): | |||
| # multi-modal tasks | |||
| image_caption = 'image-caption' | |||
| multi_modal_embedding = 'multi-modal-embedding' | |||
| visual_question_answering = 'visual-question-answering' | |||
| class Trainers(object): | |||
| @@ -105,3 +107,4 @@ class Preprocessors(object): | |||
| # multi-modal | |||
| ofa_image_caption = 'ofa-image-caption' | |||
| mplug_visual_question_answering = 'mplug-visual-question-answering' | |||
| @@ -1,2 +1,4 @@ | |||
| from .clip.clip_model import CLIPForMultiModalEmbedding | |||
| from .image_captioning_model import OfaForImageCaptioning | |||
| from .mplug_for_visual_question_answering import \ | |||
| MPlugForVisualQuestionAnswering | |||
| @@ -0,0 +1,46 @@ | |||
| from typing import Dict | |||
| from ...metainfo import Models | |||
| from ...utils.constant import Tasks | |||
| from ..base import Model, Tensor | |||
| from ..builder import MODELS | |||
| __all__ = ['MPlugForVisualQuestionAnswering'] | |||
| @MODELS.register_module( | |||
| Tasks.visual_question_answering, module_name=Models.mplug) | |||
| class MPlugForVisualQuestionAnswering(Model): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the mplug model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from sofa.models.mplug import MPlugForVisualQuestionAnswering | |||
| self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir) | |||
| self.tokenizer = self.model.tokenizer | |||
| def train(self): | |||
| return self.model.train() | |||
| def eval(self): | |||
| return self.model.eval() | |||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
| """return the result by the model | |||
| Args: | |||
| input (Dict[str, Tensor]): the preprocessed data | |||
| Returns: | |||
| Dict[str, Tensor]: results | |||
| Example: | |||
| { | |||
| 'predictions': Tensor([[1377, 4959, 2785, 6392...])]), | |||
| } | |||
| """ | |||
| return self.model(**input)[0] | |||
| @@ -42,7 +42,10 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_TAdaConv_action-recognition'), | |||
| Tasks.multi_modal_embedding: | |||
| (Pipelines.multi_modal_embedding, | |||
| 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding') | |||
| 'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'), | |||
| Tasks.visual_question_answering: | |||
| (Pipelines.visual_question_answering, | |||
| 'damo/mplug_visual-question-answering_coco_large_en'), | |||
| } | |||
| @@ -1,2 +1,3 @@ | |||
| from .image_captioning_pipeline import ImageCaptionPipeline | |||
| from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline | |||
| from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline | |||
| @@ -0,0 +1,65 @@ | |||
| from typing import Any, Dict, Optional, Union | |||
| import torch | |||
| from ...metainfo import Pipelines | |||
| from ...models import Model | |||
| from ...models.multi_modal import MPlugForVisualQuestionAnswering | |||
| from ...preprocessors import MPlugVisualQuestionAnsweringPreprocessor | |||
| from ...utils.constant import Tasks | |||
| from ..base import Pipeline, Tensor | |||
| from ..builder import PIPELINES | |||
| __all__ = ['VisualQuestionAnsweringPipeline'] | |||
| @PIPELINES.register_module( | |||
| Tasks.visual_question_answering, | |||
| module_name=Pipelines.visual_question_answering) | |||
| class VisualQuestionAnsweringPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[MPlugForVisualQuestionAnswering, str], | |||
| preprocessor: Optional[ | |||
| MPlugVisualQuestionAnsweringPreprocessor] = None, | |||
| **kwargs): | |||
| """use `model` and `preprocessor` to create a visual question answering pipeline for prediction | |||
| Args: | |||
| model (MPlugForVisualQuestionAnswering): a model instance | |||
| preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance | |||
| """ | |||
| model = model if isinstance( | |||
| model, | |||
| MPlugForVisualQuestionAnswering) else Model.from_pretrained(model) | |||
| if preprocessor is None: | |||
| preprocessor = MPlugVisualQuestionAnsweringPreprocessor( | |||
| model.model_dir) | |||
| model.eval() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| self.tokenizer = model.tokenizer | |||
| def forward(self, inputs: Dict[str, Any], | |||
| **forward_params) -> Dict[str, Any]: | |||
| with torch.no_grad(): | |||
| return super().forward(inputs, **forward_params) | |||
| def postprocess(self, inputs: Dict[str, Tensor], | |||
| **postprocess_params) -> Dict[str, str]: | |||
| """process the prediction results | |||
| Args: | |||
| inputs (Dict[str, Any]): _description_ | |||
| Returns: | |||
| Dict[str, str]: the prediction results | |||
| """ | |||
| replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), | |||
| ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), | |||
| ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) | |||
| pred_string = self.tokenizer.decode(inputs[0][0]) | |||
| for _old, _new in replace_tokens_bert: | |||
| pred_string = pred_string.replace(_old, _new) | |||
| pred_string.strip() | |||
| return {'answer': pred_string} | |||
| @@ -6,6 +6,6 @@ from .builder import PREPROCESSORS, build_preprocessor | |||
| from .common import Compose | |||
| from .image import LoadImage, load_image | |||
| from .kws import WavToLists | |||
| from .multi_modal import OfaImageCaptionPreprocessor | |||
| from .multi_modal import * # noqa F403 | |||
| from .nlp import * # noqa F403 | |||
| from .text_to_speech import * # noqa F403 | |||
| @@ -16,6 +16,7 @@ from .image import load_image | |||
| __all__ = [ | |||
| 'OfaImageCaptionPreprocessor', | |||
| 'MPlugVisualQuestionAnsweringPreprocessor', | |||
| ] | |||
| @@ -110,3 +111,47 @@ class OfaImageCaptionPreprocessor(Preprocessor): | |||
| } | |||
| } | |||
| return sample | |||
| @PREPROCESSORS.register_module( | |||
| Fields.multi_modal, | |||
| module_name=Preprocessors.mplug_visual_question_answering) | |||
| class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """preprocess the data via 'bert-base-uncased' tokenizer and configuration | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| # tokenizer | |||
| from transformers import AutoTokenizer | |||
| self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') | |||
| # load configuration | |||
| from sofa.models.mplug import CONFIG_NAME, MPlugConfig | |||
| config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME)) | |||
| # Initialize transform | |||
| from torchvision import transforms | |||
| mean = (0.48145466, 0.4578275, 0.40821073) | |||
| std = (0.26862954, 0.26130258, 0.27577711) | |||
| self.patch_resize_transform = transforms.Compose([ | |||
| transforms.Resize((config.image_res, config.image_res), | |||
| interpolation=Image.BICUBIC), | |||
| transforms.ToTensor(), | |||
| transforms.Normalize(mean=mean, std=std), | |||
| ]) | |||
| def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
| image, question = data['image'], data['question'] | |||
| image = Image.open(image).convert('RGB') if isinstance(image, | |||
| str) else image | |||
| image = self.patch_resize_transform(image) | |||
| image = torch.stack([image], dim=0) | |||
| question = self.tokenizer([question.lower()], | |||
| padding='longest', | |||
| return_tensors='pt') | |||
| return {'image': image, 'question': question, 'train': False} | |||
| @@ -61,6 +61,7 @@ class Tasks(object): | |||
| visual_grounding = 'visual-grounding' | |||
| text_to_image_synthesis = 'text-to-image-synthesis' | |||
| multi_modal_embedding = 'multi-modal-embedding' | |||
| visual_question_answering = 'visual-question-answering' | |||
| class InputFields(object): | |||
| @@ -1 +1 @@ | |||
| https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl | |||
| https://alinlp.alibaba-inc.com/pypi/sofa-1.0.4.1-py3-none-any.whl | |||
| @@ -0,0 +1,60 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models import Model | |||
| from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering | |||
| from modelscope.pipelines import VisualQuestionAnsweringPipeline, pipeline | |||
| from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class VisualQuestionAnsweringTest(unittest.TestCase): | |||
| model_id = 'damo/mplug_visual-question-answering_coco_large_en' | |||
| input_vqa = { | |||
| 'image': 'data/test/images/image_mplug_vqa.jpg', | |||
| 'question': 'What is the woman doing?', | |||
| } | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run(self): | |||
| cache_path = snapshot_download(self.model_id) | |||
| preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path) | |||
| model = MPlugForVisualQuestionAnswering(cache_path) | |||
| pipeline1 = VisualQuestionAnsweringPipeline( | |||
| model, preprocessor=preprocessor) | |||
| pipeline2 = pipeline( | |||
| Tasks.visual_question_answering, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| print(f"question: {self.input_vqa['question']}") | |||
| print(f"pipeline1: {pipeline1(self.input_vqa)['answer']}") | |||
| print(f"pipeline2: {pipeline2(self.input_vqa)['answer']}") | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub(self): | |||
| model = Model.from_pretrained(self.model_id) | |||
| preprocessor = MPlugVisualQuestionAnsweringPreprocessor( | |||
| model.model_dir) | |||
| pipeline_vqa = pipeline( | |||
| task=Tasks.visual_question_answering, | |||
| model=model, | |||
| preprocessor=preprocessor) | |||
| print(pipeline_vqa(self.input_vqa)) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_model_name(self): | |||
| pipeline_vqa = pipeline( | |||
| Tasks.visual_question_answering, model=self.model_id) | |||
| print(pipeline_vqa(self.input_vqa)) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_default_model(self): | |||
| pipeline_vqa = pipeline(task=Tasks.visual_question_answering) | |||
| print(pipeline_vqa(self.input_vqa)) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||