
[to #42322933] Fix mplug model interface

1. Fix the mplug model interface issue
2. Fix mplug inference not supporting batch input
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10052321
master
hemu.zp, yingda.chen (3 years ago)
parent commit be2f31fc15
5 changed files with 28 additions and 16 deletions

  1. modelscope/models/multi_modal/mplug/modeling_mplug.py (+2, -0)
  2. modelscope/models/multi_modal/mplug_for_all_tasks.py (+23, -9)
  3. modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+1, -3)
  4. modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py (+1, -1)
  5. modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py (+1, -3)
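
For item 2 of the commit message, a minimal sketch of how batch captioning is expected to look from the pipeline side after this change; the model id and image paths are illustrative, not taken from this commit, and the example assumes the pipeline is handed a list of images:

    # minimal sketch, assuming an mPLUG captioning model id on ModelScope
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    captioner = pipeline(
        Tasks.image_captioning,
        model='damo/mplug_image-captioning_coco_base_en')  # assumed model id

    # with this fix, every sample in the batch is decoded, not just the first
    results = captioner(['demo1.jpg', 'demo2.jpg'])
    print(results)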

modelscope/models/multi_modal/mplug/modeling_mplug.py (+2, -0)

@@ -1868,6 +1868,8 @@ class MPlug(PreTrainedModel):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         if 'model' in checkpoint:
             checkpoint = checkpoint['model']
+        if 'module' in checkpoint:
+            checkpoint = checkpoint['module']
         checkpoint = {
             k.replace('model.', ''): v
             for k, v in checkpoint.items()
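
For readers outside the diff context, the checkpoint loading after this change amounts to the standalone sketch below; the function name is mine, and the 'module' key is the form some training wrappers use when they nest the weights inside the saved dict:

    import torch


    def load_plain_state_dict(checkpoint_path: str) -> dict:
        # sketch of the unwrapping MPlug does after this commit
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if 'model' in checkpoint:    # weights nested under a 'model' key
            checkpoint = checkpoint['model']
        if 'module' in checkpoint:   # weights nested under a 'module' key
            checkpoint = checkpoint['module']
        # strip the 'model.' prefix so keys match the bare module's parameters
        return {k.replace('model.', ''): v for k, v in checkpoint.items()}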


modelscope/models/multi_modal/mplug_for_all_tasks.py (+23, -9)

@@ -1,10 +1,13 @@
import os.path as osp
from typing import Dict, List

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from modelscope.outputs import OutputKeys
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks

__all__ = ['MPlugForAllTasks']

@@ -44,17 +47,28 @@ class MPlugForAllTasks(TorchModel):
             ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
             ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))

+        # get task from config file
+        task = Config.from_file(
+            osp.join(self.model_dir, ModelFile.CONFIGURATION)).task
+
         # inference
         if not self.training and 'question' in input:
             output = self.model(input['image'], input['question'], train=False)
-            if not isinstance(output, tuple):
-                return output
+            if task == Tasks.image_text_retrieval:
+                return {OutputKeys.SCORES: output[0].tolist()}
             topk_ids, _ = output
-            pred_string: str = self.tokenizer.decode(topk_ids[0][0])
-            for _old, _new in replace_tokens_bert:
-                pred_string = pred_string.replace(_old, _new)
-            pred_string = pred_string.strip()
-            return pred_string
+            topk_ids = [topk_ids[i][0] for i in range(len(topk_ids))]
+            pred_strings: List[str] = \
+                self.tokenizer.batch_decode(topk_ids, skip_special_tokens=True)
+            output = []
+            for pred_string in pred_strings:
+                for _old, _new in replace_tokens_bert:
+                    pred_string = pred_string.replace(_old, _new)
+                pred_string = pred_string.strip()
+                output.append(pred_string)
+            output_key = OutputKeys.CAPTION \
+                if task == Tasks.image_captioning else OutputKeys.TEXT
+            return {output_key: output}

         # train and evaluate
         import addict
@@ -71,7 +85,7 @@ class MPlugForAllTasks(TorchModel):
             index = input['index']
         output = self.model(image, answer, index, train=self.training)
         if self.training:
-            return {'loss': output}
+            return {OutputKeys.LOSS: output}

         # evaluate
         topk_ids, _ = output
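
The core of the batch fix above is the switch from decoding only topk_ids[0][0] to decoding every sample with batch_decode. A small illustration of the difference, using a BERT tokenizer from transformers purely as a stand-in (the [SEP]/[CLS]/[unused*] cleanup above suggests mPLUG's tokenizer is BERT-style):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    batch_ids = tokenizer(['a dog runs on the grass',
                           'two cats sleep on a sofa'])['input_ids']

    # old behaviour: only the first sequence in the batch was decoded
    first_only = tokenizer.decode(batch_ids[0])

    # new behaviour: every sequence in the batch is decoded
    all_decoded = tokenizer.batch_decode(batch_ids, skip_special_tokens=True)
    print(first_only)
    print(all_decoded)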


modelscope/pipelines/multi_modal/image_captioning_pipeline.py (+1, -3)

@@ -52,6 +52,4 @@ class ImageCaptioningPipeline(Pipeline):
         return super().forward(inputs, **forward_params)

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if isinstance(self.model, OfaForAllTasks):
-            return inputs
-        return {OutputKeys.CAPTION: inputs}
+        return inputs

modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py (+1, -1)

@@ -48,4 +48,4 @@ class ImageTextRetrievalPipeline(Pipeline):
         return super().forward(inputs, **forward_params)

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        return {OutputKeys.SCORES: inputs[0].tolist()}
+        return inputs

modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py (+1, -3)

@@ -56,6 +56,4 @@ class VisualQuestionAnsweringPipeline(Pipeline):
         Returns:
             Dict[str, str]: the prediction results
         """
-        if isinstance(self.model, OfaForAllTasks):
-            return inputs
-        return {OutputKeys.TEXT: inputs}
+        return inputs
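
With these postprocess simplifications, the output dict (keyed by OutputKeys.CAPTION, OutputKeys.TEXT, or OutputKeys.SCORES depending on the task) now comes straight from MPlugForAllTasks and the pipelines pass it through. A hedged usage sketch for the VQA case; the model id, image path, question, and input keys are illustrative assumptions:

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    vqa = pipeline(
        Tasks.visual_question_answering,
        model='damo/mplug_visual-question-answering_coco_large_en')  # assumed id

    # per the model change above, the decoded answer(s) arrive under
    # OutputKeys.TEXT; postprocess no longer rewraps them
    result = vqa({'image': 'demo.jpg', 'question': 'What is the woman doing?'})
    print(result)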
