Add a visual question answering (VQA) task pipeline for the MPLUG model. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9182119master
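For reviewers, a minimal usage sketch of the pipeline added in this PR. It mirrors the model id and input format used in the new unit test below; the image path points at the LFS test asset added here, and any local image path or PIL image should work.

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # build the VQA pipeline; the model id is the default registered in this PR
    vqa = pipeline(
        task=Tasks.visual_question_answering,
        model='damo/mplug_visual-question-answering_coco_large_en')

    # input is a dict holding an image (path or PIL.Image) and a question in English
    result = vqa({
        'image': 'data/test/images/image_mplug_vqa.jpg',
        'question': 'What is the woman doing?',
    })
    print(result['answer'])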
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b37b706885849037b5fa7fa44a3b78a6375f768d95ce46bfcb8e7329d038a692
+size 181725
| @@ -27,6 +27,7 @@ class Models(object): | |||||
| # multi-modal models | # multi-modal models | ||||
| ofa = 'ofa' | ofa = 'ofa' | ||||
| clip = 'clip-multi-modal-embedding' | clip = 'clip-multi-modal-embedding' | ||||
| mplug = 'mplug' | |||||
| class Pipelines(object): | class Pipelines(object): | ||||
@@ -63,6 +64,7 @@ class Pipelines(object):
     # multi-modal tasks
     image_caption = 'image-caption'
     multi_modal_embedding = 'multi-modal-embedding'
+    visual_question_answering = 'visual-question-answering'

 class Trainers(object):
@@ -105,3 +107,4 @@ class Preprocessors(object):
     # multi-modal
     ofa_image_caption = 'ofa-image-caption'
+    mplug_visual_question_answering = 'mplug-visual-question-answering'
@@ -1,2 +1,4 @@
 from .clip.clip_model import CLIPForMultiModalEmbedding
 from .image_captioning_model import OfaForImageCaptioning
+from .mplug_for_visual_question_answering import \
+    MPlugForVisualQuestionAnswering
@@ -0,0 +1,46 @@
+from typing import Dict
+
+from ...metainfo import Models
+from ...utils.constant import Tasks
+from ..base import Model, Tensor
+from ..builder import MODELS
+
+__all__ = ['MPlugForVisualQuestionAnswering']
+
+
+@MODELS.register_module(
+    Tasks.visual_question_answering, module_name=Models.mplug)
+class MPlugForVisualQuestionAnswering(Model):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the MPLUG model from the `model_dir` path.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        from sofa.models.mplug import MPlugForVisualQuestionAnswering
+        self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir)
+        self.tokenizer = self.model.tokenizer
+
+    def train(self):
+        return self.model.train()
+
+    def eval(self):
+        return self.model.eval()
+
+    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """Return the prediction produced by the underlying MPLUG model.
+
+        Args:
+            input (Dict[str, Tensor]): the preprocessed data
+
+        Returns:
+            Dict[str, Tensor]: results, e.g.
+                {
+                    'predictions': Tensor([[1377, 4959, 2785, 6392, ...]])
+                }
+        """
+        return self.model(**input)[0]
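For context, a rough sketch of how this wrapper can be driven outside the pipeline, assuming a downloaded model snapshot; the image path is a placeholder, and the input dict is the one produced by the preprocessor added later in this PR.

    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
    from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor

    cache_path = snapshot_download(
        'damo/mplug_visual-question-answering_coco_large_en')
    model = MPlugForVisualQuestionAnswering(cache_path)
    model.eval()

    # the preprocessor builds the {'image', 'question', 'train'} dict consumed by forward()
    preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path)
    inputs = preprocessor({'image': 'some_image.jpg', 'question': 'what is this?'})
    outputs = model.forward(inputs)  # predicted token ids, decoded by the pipeline's postprocess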
@@ -42,7 +42,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
         'damo/cv_TAdaConv_action-recognition'),
     Tasks.multi_modal_embedding:
     (Pipelines.multi_modal_embedding,
-     'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding')
+     'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'),
+    Tasks.visual_question_answering:
+    (Pipelines.visual_question_answering,
+     'damo/mplug_visual-question-answering_coco_large_en'),
 }
@@ -1,2 +1,3 @@
 from .image_captioning_pipeline import ImageCaptionPipeline
 from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
+from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
@@ -0,0 +1,65 @@
+import re
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from ...metainfo import Pipelines
+from ...models import Model
+from ...models.multi_modal import MPlugForVisualQuestionAnswering
+from ...preprocessors import MPlugVisualQuestionAnsweringPreprocessor
+from ...utils.constant import Tasks
+from ..base import Pipeline, Tensor
+from ..builder import PIPELINES
+
+__all__ = ['VisualQuestionAnsweringPipeline']
+
+
+@PIPELINES.register_module(
+    Tasks.visual_question_answering,
+    module_name=Pipelines.visual_question_answering)
+class VisualQuestionAnsweringPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[MPlugForVisualQuestionAnswering, str],
+                 preprocessor: Optional[
+                     MPlugVisualQuestionAnsweringPreprocessor] = None,
+                 **kwargs):
+        """Use `model` and `preprocessor` to create a visual question
+        answering pipeline for prediction.
+
+        Args:
+            model (MPlugForVisualQuestionAnswering): a model instance or a model id
+            preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance
+        """
+        model = model if isinstance(
+            model,
+            MPlugForVisualQuestionAnswering) else Model.from_pretrained(model)
+        if preprocessor is None:
+            preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
+                model.model_dir)
+        model.eval()
+        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
+        self.tokenizer = model.tokenizer
+
+    def forward(self, inputs: Dict[str, Any],
+                **forward_params) -> Dict[str, Any]:
+        with torch.no_grad():
+            return super().forward(inputs, **forward_params)
+
+    def postprocess(self, inputs: Dict[str, Tensor],
+                    **postprocess_params) -> Dict[str, str]:
+        """Decode the predicted token ids into an answer string.
+
+        Args:
+            inputs (Dict[str, Tensor]): the token ids predicted by the model
+
+        Returns:
+            Dict[str, str]: the prediction results
+        """
+        # special tokens emitted by the BERT-style tokenizer that should be dropped
+        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
+                               ('[unused1]', ''), ('[SEP]', ''),
+                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
+
+        pred_string = self.tokenizer.decode(inputs[0][0])
+        for _old, _new in replace_tokens_bert:
+            pred_string = pred_string.replace(_old, _new)
+        # collapse the whitespace left behind by the removed tokens
+        pred_string = re.sub(r' +', ' ', pred_string).strip()
+        return {'answer': pred_string}
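To illustrate what the postprocess step does, a small standalone sketch of the token cleanup; the decoded string here is made up for illustration only.

    import re

    decoded = '[CLS] riding a horse [SEP] [PAD] [PAD]'
    for tok in ('[unused0]', '[PAD]', '[unused1]', '[SEP]',
                '[unused2]', '[CLS]', '[UNK]'):
        decoded = decoded.replace(tok, '')
    answer = re.sub(r' +', ' ', decoded).strip()
    print(answer)  # -> 'riding a horse'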
@@ -6,6 +6,6 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .kws import WavToLists
-from .multi_modal import OfaImageCaptionPreprocessor
+from .multi_modal import *  # noqa F403
 from .nlp import *  # noqa F403
 from .text_to_speech import *  # noqa F403
@@ -16,6 +16,7 @@ from .image import load_image
 __all__ = [
     'OfaImageCaptionPreprocessor',
+    'MPlugVisualQuestionAnsweringPreprocessor',
 ]
@@ -110,3 +111,47 @@ class OfaImageCaptionPreprocessor(Preprocessor):
             }
         }
         return sample
+
+
+@PREPROCESSORS.register_module(
+    Fields.multi_modal,
+    module_name=Preprocessors.mplug_visual_question_answering)
+class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Preprocess the data with the 'bert-base-uncased' tokenizer and the
+        image transform defined by the MPLUG configuration.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(*args, **kwargs)
+
+        # tokenizer
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+
+        # load configuration
+        from sofa.models.mplug import CONFIG_NAME, MPlugConfig
+        config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME))
+
+        # initialize the image transform
+        from torchvision import transforms
+        mean = (0.48145466, 0.4578275, 0.40821073)
+        std = (0.26862954, 0.26130258, 0.27577711)
+        self.patch_resize_transform = transforms.Compose([
+            transforms.Resize((config.image_res, config.image_res),
+                              interpolation=Image.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image, question = data['image'], data['question']
+        image = Image.open(image).convert('RGB') if isinstance(image,
+                                                               str) else image
+        image = self.patch_resize_transform(image)
+        image = torch.stack([image], dim=0)
+        question = self.tokenizer([question.lower()],
+                                  padding='longest',
+                                  return_tensors='pt')
+        return {'image': image, 'question': question, 'train': False}
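A quick sketch of what the preprocessor produces, assuming a local snapshot of the model is available; the shapes follow from the transform above, with `image_res` coming from the model's YAML config.

    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor

    model_dir = snapshot_download(
        'damo/mplug_visual-question-answering_coco_large_en')
    preprocessor = MPlugVisualQuestionAnsweringPreprocessor(model_dir)

    out = preprocessor({
        'image': 'data/test/images/image_mplug_vqa.jpg',  # path or a PIL.Image
        'question': 'What is the woman doing?',
    })
    print(out['image'].shape)         # (1, 3, config.image_res, config.image_res)
    print(out['question'].input_ids)  # token ids of the lower-cased question
    print(out['train'])               # False: inference mode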
| @@ -61,6 +61,7 @@ class Tasks(object): | |||||
| visual_grounding = 'visual-grounding' | visual_grounding = 'visual-grounding' | ||||
| text_to_image_synthesis = 'text-to-image-synthesis' | text_to_image_synthesis = 'text-to-image-synthesis' | ||||
| multi_modal_embedding = 'multi-modal-embedding' | multi_modal_embedding = 'multi-modal-embedding' | ||||
| visual_question_answering = 'visual-question-answering' | |||||
| class InputFields(object): | class InputFields(object): | ||||
@@ -1 +1 @@
-https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.4.1-py3-none-any.whl
@@ -0,0 +1,60 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.models import Model
+from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
+from modelscope.pipelines import VisualQuestionAnsweringPipeline, pipeline
+from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class VisualQuestionAnsweringTest(unittest.TestCase):
+    model_id = 'damo/mplug_visual-question-answering_coco_large_en'
+    input_vqa = {
+        'image': 'data/test/images/image_mplug_vqa.jpg',
+        'question': 'What is the woman doing?',
+    }
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run(self):
+        cache_path = snapshot_download(self.model_id)
+        preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path)
+        model = MPlugForVisualQuestionAnswering(cache_path)
+        pipeline1 = VisualQuestionAnsweringPipeline(
+            model, preprocessor=preprocessor)
+        pipeline2 = pipeline(
+            Tasks.visual_question_answering,
+            model=model,
+            preprocessor=preprocessor)
+        print(f"question: {self.input_vqa['question']}")
+        print(f"pipeline1: {pipeline1(self.input_vqa)['answer']}")
+        print(f"pipeline2: {pipeline2(self.input_vqa)['answer']}")
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        model = Model.from_pretrained(self.model_id)
+        preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
+            model.model_dir)
+        pipeline_vqa = pipeline(
+            task=Tasks.visual_question_answering,
+            model=model,
+            preprocessor=preprocessor)
+        print(pipeline_vqa(self.input_vqa))
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name(self):
+        pipeline_vqa = pipeline(
+            Tasks.visual_question_answering, model=self.model_id)
+        print(pipeline_vqa(self.input_vqa))
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_with_default_model(self):
+        pipeline_vqa = pipeline(task=Tasks.visual_question_answering)
+        print(pipeline_vqa(self.input_vqa))
+
+
+if __name__ == '__main__':
+    unittest.main()