From 42be514bac5f985d6d8ce710646e2a79e3d81d39 Mon Sep 17 00:00:00 2001 From: ly261666 Date: Wed, 12 Oct 2022 15:17:11 +0800 Subject: [PATCH] [to #42322933]update fer to satisfy demo service requirements Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10372291 --- .../pipelines/cv/facial_expression_recognition_pipeline.py | 6 +----- modelscope/utils/cv/image_utils.py | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py b/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py index b598a457..3c85ae62 100644 --- a/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py +++ b/modelscope/pipelines/cv/facial_expression_recognition_pipeline.py @@ -122,11 +122,7 @@ class FacialExpressionRecognitionPipeline(Pipeline): result = self.fer(input) assert result is not None scores = result[0].tolist() - labels = result[1].tolist() - return { - OutputKeys.SCORES: scores, - OutputKeys.LABELS: self.map_list[labels] - } + return {OutputKeys.SCORES: scores, OutputKeys.LABELS: self.map_list} def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: return inputs diff --git a/modelscope/utils/cv/image_utils.py b/modelscope/utils/cv/image_utils.py index ad0d6c8e..eab74688 100644 --- a/modelscope/utils/cv/image_utils.py +++ b/modelscope/utils/cv/image_utils.py @@ -113,7 +113,9 @@ def draw_face_detection_no_lm_result(img_path, detection_result): def draw_facial_expression_result(img_path, facial_expression_result): - label = facial_expression_result[OutputKeys.LABELS] + scores = facial_expression_result[OutputKeys.SCORES] + labels = facial_expression_result[OutputKeys.LABELS] + label = labels[np.argmax(scores)] img = cv2.imread(img_path) assert img is not None, f"Can't read img: {img_path}" cv2.putText(