From a4cfbaa0ddb448ed8e3b33917e220ec13f7bef5b Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Fri, 9 Sep 2022 14:56:05 +0800 Subject: [PATCH] [to #42322933] revert mplug batch inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 由于之前实现的 batch 化 inference 与 pipelines/base.py 中输入 List[Input] 的情况存在冲突,移除了此处之前实现的 batch 化 inference 代码。mplug 模型在 pipeline 中推理时输入只接受 Image.Image,str,tuple,dict 类型,对于 List[Input] 的情况由 pipelines/base.py 中的代码进行处理 --- .../models/multi_modal/mplug_for_all_tasks.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/modelscope/models/multi_modal/mplug_for_all_tasks.py b/modelscope/models/multi_modal/mplug_for_all_tasks.py index a06e5800..d61fea10 100644 --- a/modelscope/models/multi_modal/mplug_for_all_tasks.py +++ b/modelscope/models/multi_modal/mplug_for_all_tasks.py @@ -57,18 +57,14 @@ class MPlugForAllTasks(TorchModel): if task == Tasks.image_text_retrieval: return {OutputKeys.SCORES: output[0].tolist()} topk_ids, _ = output - topk_ids = [topk_ids[i][0] for i in range(len(topk_ids))] - pred_strings: List[str] = \ - self.tokenizer.batch_decode(topk_ids, skip_special_tokens=True) - output = [] - for pred_string in pred_strings: - for _old, _new in replace_tokens_bert: - pred_string = pred_string.replace(_old, _new) - pred_string = pred_string.strip() - output.append(pred_string) + pred_string: List[str] = \ + self.tokenizer.decode(topk_ids[0][0]) + for _old, _new in replace_tokens_bert: + pred_string = pred_string.replace(_old, _new) + pred_string = pred_string.strip() output_key = OutputKeys.CAPTION \ if task == Tasks.image_captioning else OutputKeys.TEXT - return {output_key: output} + return {output_key: pred_string} # train and evaluate import addict