diff --git a/modelscope/models/multi_modal/mplug/modeling_mplug.py b/modelscope/models/multi_modal/mplug/modeling_mplug.py
index f469c218..ec491f1d 100755
--- a/modelscope/models/multi_modal/mplug/modeling_mplug.py
+++ b/modelscope/models/multi_modal/mplug/modeling_mplug.py
@@ -1868,6 +1868,8 @@ class MPlug(PreTrainedModel):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
         if 'model' in checkpoint:
             checkpoint = checkpoint['model']
+        if 'module' in checkpoint:
+            checkpoint = checkpoint['module']
         checkpoint = {
             k.replace('model.', ''): v
             for k, v in checkpoint.items()
diff --git a/modelscope/models/multi_modal/mplug_for_all_tasks.py b/modelscope/models/multi_modal/mplug_for_all_tasks.py
index 608cc733..a06e5800 100644
--- a/modelscope/models/multi_modal/mplug_for_all_tasks.py
+++ b/modelscope/models/multi_modal/mplug_for_all_tasks.py
@@ -1,10 +1,13 @@
+import os.path as osp
 from typing import Dict, List
 
 from modelscope.metainfo import Models
 from modelscope.models import TorchModel
 from modelscope.models.base import Tensor
 from modelscope.models.builder import MODELS
-from modelscope.utils.constant import Tasks
+from modelscope.outputs import OutputKeys
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
 
 __all__ = ['MPlugForAllTasks']
 
@@ -44,17 +47,28 @@ class MPlugForAllTasks(TorchModel):
                                ('[unused1]', ''), (r' +', ' '), ('[SEP]', ''),
                                ('[unused2]', ''), ('[CLS]', ''), ('[UNK]', ''))
 
+        # get task from config file
+        task = Config.from_file(
+            osp.join(self.model_dir, ModelFile.CONFIGURATION)).task
+
         # inference
         if not self.training and 'question' in input:
             output = self.model(input['image'], input['question'], train=False)
-            if not isinstance(output, tuple):
-                return output
+            if task == Tasks.image_text_retrieval:
+                return {OutputKeys.SCORES: output[0].tolist()}
             topk_ids, _ = output
-            pred_string: str = self.tokenizer.decode(topk_ids[0][0])
-            for _old, _new in replace_tokens_bert:
-                pred_string = pred_string.replace(_old, _new)
-            pred_string = pred_string.strip()
-            return pred_string
+            topk_ids = [topk_ids[i][0] for i in range(len(topk_ids))]
+            pred_strings: List[str] = \
+                self.tokenizer.batch_decode(topk_ids, skip_special_tokens=True)
+            output = []
+            for pred_string in pred_strings:
+                for _old, _new in replace_tokens_bert:
+                    pred_string = pred_string.replace(_old, _new)
+                pred_string = pred_string.strip()
+                output.append(pred_string)
+            output_key = OutputKeys.CAPTION \
+                if task == Tasks.image_captioning else OutputKeys.TEXT
+            return {output_key: output}
 
         # train and evaluate
         import addict
@@ -71,7 +85,7 @@ class MPlugForAllTasks(TorchModel):
         index = input['index']
         output = self.model(image, answer, index, train=self.training)
         if self.training:
-            return {'loss': output}
+            return {OutputKeys.LOSS: output}
 
         # evaluate
         topk_ids, _ = output
diff --git a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
index 99cccee1..81a5f8cd 100644
--- a/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
+++ b/modelscope/pipelines/multi_modal/image_captioning_pipeline.py
@@ -52,6 +52,4 @@ class ImageCaptioningPipeline(Pipeline):
         return super().forward(inputs, **forward_params)
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if isinstance(self.model, OfaForAllTasks):
-            return inputs
-        return {OutputKeys.CAPTION: inputs}
+        return inputs
diff --git a/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py b/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py
index 1ebcf526..329d79bf 100644
--- a/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py
+++ b/modelscope/pipelines/multi_modal/image_text_retrieval_pipeline.py
@@ -48,4 +48,4 @@ class ImageTextRetrievalPipeline(Pipeline):
         return super().forward(inputs, **forward_params)
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        return {OutputKeys.SCORES: inputs[0].tolist()}
+        return inputs
diff --git a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py
index b2442a3e..86177074 100644
--- a/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py
+++ b/modelscope/pipelines/multi_modal/visual_question_answering_pipeline.py
@@ -56,6 +56,4 @@ class VisualQuestionAnsweringPipeline(Pipeline):
         Returns:
             Dict[str, str]: the prediction results
         """
-        if isinstance(self.model, OfaForAllTasks):
-            return inputs
-        return {OutputKeys.TEXT: inputs}
+        return inputs
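For reference, a minimal standalone sketch of the checkpoint-unwrapping pattern the first hunk extends. The new 'module' branch handles checkpoints that nest weights under that key; attributing this to DeepSpeed-style wrapper engines is an assumption, not something the diff states, and the function name is illustrative:

import torch


def load_unwrapped_state_dict(checkpoint_path: str) -> dict:
    # Plain torch.save checkpoints are already flat state dicts; some
    # trainers nest them under 'model', and wrapper engines (assumed
    # here: DeepSpeed-style) nest them under 'module'.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    if 'module' in checkpoint:
        checkpoint = checkpoint['module']
    # Drop the 'model.' prefix that wrapper modules prepend to keys.
    return {k.replace('model.', ''): v for k, v in checkpoint.items()}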
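The net effect of the remaining hunks is that MPlugForAllTasks now keys its own outputs by OutputKeys (reading the task from configuration.json in the model directory), so each pipeline's postprocess() passes the dict through unchanged. A usage sketch under that contract; the model id and image path are hypothetical placeholders, not names from the diff:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Hypothetical model id and image path, for illustration only.
captioner = pipeline(Tasks.image_captioning, model='damo/mplug-captioning')
result = captioner('demo.jpg')
# postprocess() now forwards the model's dict as-is, so callers read
# the caption list under the same key the model wrote it with.
print(result[OutputKeys.CAPTION])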