diff --git a/modelscope/models/multi_modal/mplug_for_all_tasks.py b/modelscope/models/multi_modal/mplug_for_all_tasks.py index a06e5800..d61fea10 100644 --- a/modelscope/models/multi_modal/mplug_for_all_tasks.py +++ b/modelscope/models/multi_modal/mplug_for_all_tasks.py @@ -57,18 +57,14 @@ class MPlugForAllTasks(TorchModel): if task == Tasks.image_text_retrieval: return {OutputKeys.SCORES: output[0].tolist()} topk_ids, _ = output - topk_ids = [topk_ids[i][0] for i in range(len(topk_ids))] - pred_strings: List[str] = \ - self.tokenizer.batch_decode(topk_ids, skip_special_tokens=True) - output = [] - for pred_string in pred_strings: - for _old, _new in replace_tokens_bert: - pred_string = pred_string.replace(_old, _new) - pred_string = pred_string.strip() - output.append(pred_string) + pred_string: List[str] = \ + self.tokenizer.decode(topk_ids[0][0]) + for _old, _new in replace_tokens_bert: + pred_string = pred_string.replace(_old, _new) + pred_string = pred_string.strip() output_key = OutputKeys.CAPTION \ if task == Tasks.image_captioning else OutputKeys.TEXT - return {output_key: output} + return {output_key: pred_string} # train and evaluate import addict