
[to #42322933] Add MPLUG model

Add a visual question answering task pipeline for the MPLUG model.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9182119
master
hemu.zp, 3 years ago
commit fabea5604e
12 changed files with 232 additions and 3 deletions
  1. data/test/images/image_mplug_vqa.jpg (+3, -0)
  2. modelscope/metainfo.py (+3, -0)
  3. modelscope/models/multi_modal/__init__.py (+2, -0)
  4. modelscope/models/multi_modal/mplug_for_visual_question_answering.py (+46, -0)
  5. modelscope/pipelines/builder.py (+4, -1)
  6. modelscope/pipelines/multi_modal/__init__.py (+1, -0)
  7. modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py (+65, -0)
  8. modelscope/preprocessors/__init__.py (+1, -1)
  9. modelscope/preprocessors/multi_modal.py (+45, -0)
  10. modelscope/utils/constant.py (+1, -0)
  11. requirements/nlp.txt (+1, -1)
  12. tests/pipelines/test_visual_question_answering.py (+60, -0)

data/test/images/image_mplug_vqa.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b37b706885849037b5fa7fa44a3b78a6375f768d95ce46bfcb8e7329d038a692
size 181725

modelscope/metainfo.py (+3, -0)

@@ -27,6 +27,7 @@ class Models(object):
     # multi-modal models
     ofa = 'ofa'
     clip = 'clip-multi-modal-embedding'
+    mplug = 'mplug'


 class Pipelines(object):
@@ -63,6 +64,7 @@ class Pipelines(object):
     # multi-modal tasks
     image_caption = 'image-caption'
     multi_modal_embedding = 'multi-modal-embedding'
+    visual_question_answering = 'visual-question-answering'


 class Trainers(object):
@@ -105,3 +107,4 @@ class Preprocessors(object):

     # multi-modal
     ofa_image_caption = 'ofa-image-caption'
+    mplug_visual_question_answering = 'mplug-visual-question-answering'

modelscope/models/multi_modal/__init__.py (+2, -0)

@@ -1,2 +1,4 @@
 from .clip.clip_model import CLIPForMultiModalEmbedding
 from .image_captioning_model import OfaForImageCaptioning
+from .mplug_for_visual_question_answering import \
+    MPlugForVisualQuestionAnswering

modelscope/models/multi_modal/mplug_for_visual_question_answering.py (+46, -0)

@@ -0,0 +1,46 @@
from typing import Dict

from ...metainfo import Models
from ...utils.constant import Tasks
from ..base import Model, Tensor
from ..builder import MODELS

__all__ = ['MPlugForVisualQuestionAnswering']


@MODELS.register_module(
    Tasks.visual_question_answering, module_name=Models.mplug)
class MPlugForVisualQuestionAnswering(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the mplug model from the `model_dir` path.
        Args:
            model_dir (str): the model path.
        """

        super().__init__(model_dir, *args, **kwargs)
        from sofa.models.mplug import MPlugForVisualQuestionAnswering
        self.model = MPlugForVisualQuestionAnswering.from_pretrained(model_dir)
        self.tokenizer = self.model.tokenizer

    def train(self):
        return self.model.train()

    def eval(self):
        return self.model.eval()

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """return the result by the model

        Args:
            input (Dict[str, Tensor]): the preprocessed data

        Returns:
            Dict[str, Tensor]: results
            Example:
                {
                    'predictions': Tensor([[1377, 4959, 2785, 6392...]]),
                }
        """

        return self.model(**input)[0]

modelscope/pipelines/builder.py (+4, -1)

@@ -42,7 +42,10 @@ DEFAULT_MODEL_FOR_PIPELINE = {
      'damo/cv_TAdaConv_action-recognition'),
     Tasks.multi_modal_embedding:
     (Pipelines.multi_modal_embedding,
-     'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding')
+     'damo/multi-modal_clip-vit-large-patch14-chinese_multi-modal-embedding'),
+    Tasks.visual_question_answering:
+    (Pipelines.visual_question_answering,
+     'damo/mplug_visual-question-answering_coco_large_en'),
 }




modelscope/pipelines/multi_modal/__init__.py (+1, -0)

@@ -1,2 +1,3 @@
 from .image_captioning_pipeline import ImageCaptionPipeline
 from .multi_modal_embedding_pipeline import MultiModalEmbeddingPipeline
+from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline

modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py (+65, -0)

@@ -0,0 +1,65 @@
from typing import Any, Dict, Optional, Union

import torch

from ...metainfo import Pipelines
from ...models import Model
from ...models.multi_modal import MPlugForVisualQuestionAnswering
from ...preprocessors import MPlugVisualQuestionAnsweringPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline, Tensor
from ..builder import PIPELINES

__all__ = ['VisualQuestionAnsweringPipeline']


@PIPELINES.register_module(
    Tasks.visual_question_answering,
    module_name=Pipelines.visual_question_answering)
class VisualQuestionAnsweringPipeline(Pipeline):

    def __init__(self,
                 model: Union[MPlugForVisualQuestionAnswering, str],
                 preprocessor: Optional[
                     MPlugVisualQuestionAnsweringPreprocessor] = None,
                 **kwargs):
        """use `model` and `preprocessor` to create a visual question answering pipeline for prediction

        Args:
            model (MPlugForVisualQuestionAnswering): a model instance
            preprocessor (MPlugVisualQuestionAnsweringPreprocessor): a preprocessor instance
        """
        model = model if isinstance(
            model,
            MPlugForVisualQuestionAnswering) else Model.from_pretrained(model)
        if preprocessor is None:
            preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
                model.model_dir)
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.tokenizer = model.tokenizer

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            return super().forward(inputs, **forward_params)

    def postprocess(self, inputs: Dict[str, Tensor],
                    **postprocess_params) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Tensor]): the model prediction outputs

        Returns:
            Dict[str, str]: the prediction results
        """
        replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
                               ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                               ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))

        pred_string = self.tokenizer.decode(inputs[0][0])
        for _old, _new in replace_tokens_bert:
            pred_string = pred_string.replace(_old, _new)
        pred_string = pred_string.strip()
        return {'answer': pred_string}
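
For reference, the token cleanup in postprocess above is a plain (non-regex) string replacement of BERT special tokens followed by a strip. A standalone sketch of the same logic; the decoded string below is made up for illustration, real values come from self.tokenizer.decode(inputs[0][0]):

# Hypothetical decoded output used only to illustrate the cleanup.
decoded = '[CLS] riding a horse [SEP] [PAD] [PAD]'

replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''),
                       ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                       ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))

answer = decoded
for _old, _new in replace_tokens_bert:
    answer = answer.replace(_old, _new)  # literal str.replace, not a regex sub
answer = answer.strip()
print(answer)  # riding a horse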

modelscope/preprocessors/__init__.py (+1, -1)

@@ -6,6 +6,6 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .kws import WavToLists
-from .multi_modal import OfaImageCaptionPreprocessor
+from .multi_modal import *  # noqa F403
 from .nlp import *  # noqa F403
 from .text_to_speech import *  # noqa F403

modelscope/preprocessors/multi_modal.py (+45, -0)

@@ -16,6 +16,7 @@ from .image import load_image

 __all__ = [
     'OfaImageCaptionPreprocessor',
+    'MPlugVisualQuestionAnsweringPreprocessor',
 ]


@@ -110,3 +111,47 @@ class OfaImageCaptionPreprocessor(Preprocessor):
             }
         }
         return sample
+
+
+@PREPROCESSORS.register_module(
+    Fields.multi_modal,
+    module_name=Preprocessors.mplug_visual_question_answering)
+class MPlugVisualQuestionAnsweringPreprocessor(Preprocessor):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """preprocess the data via 'bert-base-uncased' tokenizer and configuration
+
+        """
+        super().__init__(*args, **kwargs)
+
+        # tokenizer
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+
+        # load configuration
+        from sofa.models.mplug import CONFIG_NAME, MPlugConfig
+        config = MPlugConfig.from_yaml_file(osp.join(model_dir, CONFIG_NAME))
+
+        # Initialize transform
+        from torchvision import transforms
+        mean = (0.48145466, 0.4578275, 0.40821073)
+        std = (0.26862954, 0.26130258, 0.27577711)
+
+        self.patch_resize_transform = transforms.Compose([
+            transforms.Resize((config.image_res, config.image_res),
+                              interpolation=Image.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=mean, std=std),
+        ])
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image, question = data['image'], data['question']
+        image = Image.open(image).convert('RGB') if isinstance(image,
+                                                               str) else image
+        image = self.patch_resize_transform(image)
+        image = torch.stack([image], dim=0)
+        question = self.tokenizer([question.lower()],
+                                  padding='longest',
+                                  return_tensors='pt')
+
+        return {'image': image, 'question': question, 'train': False}
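
The preprocessor can also be exercised on its own. A short sketch, assuming the model files are downloaded first (snapshot_download mirrors its use in the new test); the shape comment follows from the transform and torch.stack above:

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor

# Download the model files so the MPLUG config (image_res) is available.
model_dir = snapshot_download('damo/mplug_visual-question-answering_coco_large_en')
preprocessor = MPlugVisualQuestionAnsweringPreprocessor(model_dir)

inputs = preprocessor({
    'image': 'data/test/images/image_mplug_vqa.jpg',
    'question': 'What is the woman doing?',
})
# inputs['image'] is a 1 x 3 x image_res x image_res float tensor,
# inputs['question'] is the tokenizer output, and inputs['train'] is False.
print(inputs['image'].shape, inputs['train'])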

modelscope/utils/constant.py (+1, -0)

@@ -61,6 +61,7 @@ class Tasks(object):
     visual_grounding = 'visual-grounding'
     text_to_image_synthesis = 'text-to-image-synthesis'
     multi_modal_embedding = 'multi-modal-embedding'
+    visual_question_answering = 'visual-question-answering'


 class InputFields(object):


requirements/nlp.txt (+1, -1)

@@ -1 +1 @@
-https://alinlp.alibaba-inc.com/pypi/sofa-1.0.3-py3-none-any.whl
+https://alinlp.alibaba-inc.com/pypi/sofa-1.0.4.1-py3-none-any.whl

tests/pipelines/test_visual_question_answering.py (+60, -0)

@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.multi_modal import MPlugForVisualQuestionAnswering
from modelscope.pipelines import VisualQuestionAnsweringPipeline, pipeline
from modelscope.preprocessors import MPlugVisualQuestionAnsweringPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class VisualQuestionAnsweringTest(unittest.TestCase):
    model_id = 'damo/mplug_visual-question-answering_coco_large_en'
    input_vqa = {
        'image': 'data/test/images/image_mplug_vqa.jpg',
        'question': 'What is the woman doing?',
    }

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run(self):
        cache_path = snapshot_download(self.model_id)
        preprocessor = MPlugVisualQuestionAnsweringPreprocessor(cache_path)
        model = MPlugForVisualQuestionAnswering(cache_path)
        pipeline1 = VisualQuestionAnsweringPipeline(
            model, preprocessor=preprocessor)
        pipeline2 = pipeline(
            Tasks.visual_question_answering,
            model=model,
            preprocessor=preprocessor)
        print(f"question: {self.input_vqa['question']}")
        print(f"pipeline1: {pipeline1(self.input_vqa)['answer']}")
        print(f"pipeline2: {pipeline2(self.input_vqa)['answer']}")

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        preprocessor = MPlugVisualQuestionAnsweringPreprocessor(
            model.model_dir)
        pipeline_vqa = pipeline(
            task=Tasks.visual_question_answering,
            model=model,
            preprocessor=preprocessor)
        print(pipeline_vqa(self.input_vqa))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_name(self):
        pipeline_vqa = pipeline(
            Tasks.visual_question_answering, model=self.model_id)
        print(pipeline_vqa(self.input_vqa))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_vqa = pipeline(task=Tasks.visual_question_answering)
        print(pipeline_vqa(self.input_vqa))


if __name__ == '__main__':
    unittest.main()
