From 5cec664b5dd49db543c816d75f37e3d79553edc1 Mon Sep 17 00:00:00 2001 From: "xiachen.wyh" Date: Tue, 9 Aug 2022 17:27:11 +0800 Subject: [PATCH] [to #42322933]cv/tinynas/classification2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix to the output form to list 调整输出格式,改为 list 格式 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9684053 --- modelscope/pipelines/cv/tinynas_classification_pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelscope/pipelines/cv/tinynas_classification_pipeline.py b/modelscope/pipelines/cv/tinynas_classification_pipeline.py index d49166d1..a470e58b 100644 --- a/modelscope/pipelines/cv/tinynas_classification_pipeline.py +++ b/modelscope/pipelines/cv/tinynas_classification_pipeline.py @@ -90,7 +90,7 @@ class TinynasClassificationPipeline(Pipeline): output_prob = torch.nn.functional.softmax(inputs['outputs'], dim=-1) score = torch.max(output_prob) output_dict = { - OutputKeys.SCORES: score.item(), - OutputKeys.LABELS: label_dict[inputs['outputs'].argmax().item()] + OutputKeys.SCORES: [score.item()], + OutputKeys.LABELS: [label_dict[inputs['outputs'].argmax().item()]] } return output_dict