|
|
|
@@ -1,10 +1,13 @@ |
|
|
|
import os.path as osp |
|
|
|
from typing import Dict, List |
|
|
|
|
|
|
|
from modelscope.metainfo import Models |
|
|
|
from modelscope.models import TorchModel |
|
|
|
from modelscope.models.base import Tensor |
|
|
|
from modelscope.models.builder import MODELS |
|
|
|
from modelscope.utils.constant import Tasks |
|
|
|
from modelscope.outputs import OutputKeys |
|
|
|
from modelscope.utils.config import Config |
|
|
|
from modelscope.utils.constant import ModelFile, Tasks |
|
|
|
|
|
|
|
__all__ = ['MPlugForAllTasks'] |
|
|
|
|
|
|
|
@@ -44,17 +47,28 @@ class MPlugForAllTasks(TorchModel): |
|
|
|
('[unused1]', ''), (r' +', ' '), ('[SEP]', ''), |
|
|
|
('[unused2]', ''), ('[CLS]', ''), ('[UNK]', '')) |
|
|
|
|
|
|
|
# get task from config file |
|
|
|
task = Config.from_file( |
|
|
|
osp.join(self.model_dir, ModelFile.CONFIGURATION)).task |
|
|
|
|
|
|
|
# inference |
|
|
|
if not self.training and 'question' in input: |
|
|
|
output = self.model(input['image'], input['question'], train=False) |
|
|
|
if not isinstance(output, tuple): |
|
|
|
return output |
|
|
|
if task == Tasks.image_text_retrieval: |
|
|
|
return {OutputKeys.SCORES: output[0].tolist()} |
|
|
|
topk_ids, _ = output |
|
|
|
pred_string: str = self.tokenizer.decode(topk_ids[0][0]) |
|
|
|
for _old, _new in replace_tokens_bert: |
|
|
|
pred_string = pred_string.replace(_old, _new) |
|
|
|
pred_string = pred_string.strip() |
|
|
|
return pred_string |
|
|
|
topk_ids = [topk_ids[i][0] for i in range(len(topk_ids))] |
|
|
|
pred_strings: List[str] = \ |
|
|
|
self.tokenizer.batch_decode(topk_ids, skip_special_tokens=True) |
|
|
|
output = [] |
|
|
|
for pred_string in pred_strings: |
|
|
|
for _old, _new in replace_tokens_bert: |
|
|
|
pred_string = pred_string.replace(_old, _new) |
|
|
|
pred_string = pred_string.strip() |
|
|
|
output.append(pred_string) |
|
|
|
output_key = OutputKeys.CAPTION \ |
|
|
|
if task == Tasks.image_captioning else OutputKeys.TEXT |
|
|
|
return {output_key: output} |
|
|
|
|
|
|
|
# train and evaluate |
|
|
|
import addict |
|
|
|
@@ -71,7 +85,7 @@ class MPlugForAllTasks(TorchModel): |
|
|
|
index = input['index'] |
|
|
|
output = self.model(image, answer, index, train=self.training) |
|
|
|
if self.training: |
|
|
|
return {'loss': output} |
|
|
|
return {OutputKeys.LOSS: output} |
|
|
|
|
|
|
|
# evaluate |
|
|
|
topk_ids, _ = output |
|
|
|
|