diff --git a/modelscope/pipelines/cv/image_classification_pipeline.py b/modelscope/pipelines/cv/image_classification_pipeline.py index 439ea6d3..49467eab 100644 --- a/modelscope/pipelines/cv/image_classification_pipeline.py +++ b/modelscope/pipelines/cv/image_classification_pipeline.py @@ -102,13 +102,17 @@ class GeneralImageClassificationPipeline(Pipeline): def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: scores = inputs['scores'] - pred_score = np.max(scores, axis=1)[0] - pred_label = np.argmax(scores, axis=1)[0] - result = {'pred_label': pred_label, 'pred_score': float(pred_score)} - result['pred_class'] = self.model.CLASSES[result['pred_label']] + + pred_scores = np.sort(scores, axis=1)[0][::-1][:5] + pred_labels = np.argsort(scores, axis=1)[0][::-1][:5] + + result = {'pred_score': [score for score in pred_scores]} + result['pred_class'] = [ + self.model.CLASSES[lable] for lable in pred_labels + ] outputs = { - OutputKeys.SCORES: [result['pred_score']], - OutputKeys.LABELS: [result['pred_class']] + OutputKeys.SCORES: result['pred_score'], + OutputKeys.LABELS: result['pred_class'] } return outputs