From 7720ae50e241ed3a5cf319d9410b774228d8126c Mon Sep 17 00:00:00 2001 From: "jiangnana.jnn" Date: Mon, 17 Oct 2022 20:30:42 +0800 Subject: [PATCH] return dict values when input single sample for easycv pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10423383 --- .../pipelines/cv/easycv_pipelines/base.py | 18 +++++++++++++++++- .../cv/easycv_pipelines/detection_pipeline.py | 3 +++ .../face_2d_keypoints_pipeline.py | 3 +++ .../human_wholebody_keypoint_pipeline.py | 3 +++ tests/pipelines/test_face_2d_keypoints.py | 2 +- tests/pipelines/test_hand_2d_keypoints.py | 9 ++------- .../pipelines/test_human_wholebody_keypoint.py | 2 +- tests/pipelines/test_object_detection.py | 2 +- 8 files changed, 31 insertions(+), 11 deletions(-) diff --git a/modelscope/pipelines/cv/easycv_pipelines/base.py b/modelscope/pipelines/cv/easycv_pipelines/base.py index 8aea1146..c130aea0 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/base.py +++ b/modelscope/pipelines/cv/easycv_pipelines/base.py @@ -4,7 +4,9 @@ import os import os.path as osp from typing import Any +import numpy as np from easycv.utils.ms_utils import EasyCVMeta +from PIL import ImageFile from modelscope.hub.snapshot_download import snapshot_download from modelscope.pipelines.util import is_official_hub_path @@ -94,5 +96,19 @@ class EasyCVPipeline(object): return easycv_config + def _is_single_inputs(self, inputs): + if isinstance(inputs, str) or (isinstance(inputs, list) + and len(inputs) == 1) or isinstance( + inputs, np.ndarray) or isinstance( + inputs, ImageFile.ImageFile): + return True + + return False + def __call__(self, inputs) -> Any: - return self.predict_op(inputs) + outputs = self.predict_op(inputs) + + if self._is_single_inputs(inputs): + outputs = outputs[0] + + return outputs diff --git a/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py index 0c2058d5..a1173bc4 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py +++ b/modelscope/pipelines/cv/easycv_pipelines/detection_pipeline.py @@ -57,4 +57,7 @@ class EasyCVDetectionPipeline(EasyCVPipeline): OutputKeys.BOXES: boxes } for output in outputs] + if self._is_single_inputs(inputs): + results = results[0] + return results diff --git a/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py index 7c32e0fc..b48d013e 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/easycv_pipelines/face_2d_keypoints_pipeline.py @@ -40,4 +40,7 @@ class Face2DKeypointsPipeline(EasyCVPipeline): OutputKeys.POSES: output['pose'] } for output in outputs] + if self._is_single_inputs(inputs): + results = results[0] + return results diff --git a/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py b/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py index 263f8225..936accbf 100644 --- a/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py +++ b/modelscope/pipelines/cv/easycv_pipelines/human_wholebody_keypoint_pipeline.py @@ -62,4 +62,7 @@ class HumanWholebodyKeypointsPipeline(EasyCVPipeline): OutputKeys.BOXES: output['boxes'] } for output in outputs] + if self._is_single_inputs(inputs): + results = results[0] + return results diff --git a/tests/pipelines/test_face_2d_keypoints.py b/tests/pipelines/test_face_2d_keypoints.py index 667ecddc..a5e347e8 100644 --- a/tests/pipelines/test_face_2d_keypoints.py +++ b/tests/pipelines/test_face_2d_keypoints.py @@ -18,7 +18,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): face_2d_keypoints_align = pipeline( task=Tasks.face_2d_keypoints, model=model_id) - output = face_2d_keypoints_align(img_path)[0] + output = face_2d_keypoints_align(img_path) output_keypoints = output[OutputKeys.KEYPOINTS] output_pose = output[OutputKeys.POSES] diff --git a/tests/pipelines/test_hand_2d_keypoints.py b/tests/pipelines/test_hand_2d_keypoints.py index 86cd2d06..43b569d0 100644 --- a/tests/pipelines/test_hand_2d_keypoints.py +++ b/tests/pipelines/test_hand_2d_keypoints.py @@ -15,10 +15,8 @@ class Hand2DKeypointsPipelineTest(unittest.TestCase): model_id = 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody' hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints, model=model_id) - outputs = hand_keypoint(img_path) - self.assertEqual(len(outputs), 1) + results = hand_keypoint(img_path) - results = outputs[0] self.assertIn(OutputKeys.KEYPOINTS, results.keys()) self.assertIn(OutputKeys.BOXES, results.keys()) self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21) @@ -30,10 +28,7 @@ class Hand2DKeypointsPipelineTest(unittest.TestCase): img_path = 'data/test/images/hand_keypoints.jpg' hand_keypoint = pipeline(task=Tasks.hand_2d_keypoints) - outputs = hand_keypoint(img_path) - self.assertEqual(len(outputs), 1) - - results = outputs[0] + results = hand_keypoint(img_path) self.assertIn(OutputKeys.KEYPOINTS, results.keys()) self.assertIn(OutputKeys.BOXES, results.keys()) self.assertEqual(results[OutputKeys.KEYPOINTS].shape[1], 21) diff --git a/tests/pipelines/test_human_wholebody_keypoint.py b/tests/pipelines/test_human_wholebody_keypoint.py index b214f4e1..7c5946cc 100644 --- a/tests/pipelines/test_human_wholebody_keypoint.py +++ b/tests/pipelines/test_human_wholebody_keypoint.py @@ -18,7 +18,7 @@ class EasyCVFace2DKeypointsPipelineTest(unittest.TestCase): human_wholebody_keypoint_pipeline = pipeline( task=Tasks.human_wholebody_keypoint, model=model_id) - output = human_wholebody_keypoint_pipeline(img_path)[0] + output = human_wholebody_keypoint_pipeline(img_path) output_keypoints = output[OutputKeys.KEYPOINTS] output_pose = output[OutputKeys.BOXES] diff --git a/tests/pipelines/test_object_detection.py b/tests/pipelines/test_object_detection.py index 00a71371..64766c77 100644 --- a/tests/pipelines/test_object_detection.py +++ b/tests/pipelines/test_object_detection.py @@ -55,7 +55,7 @@ class ObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck): image_object_detection_auto = pipeline( Tasks.image_object_detection, model=model_id) - result = image_object_detection_auto(test_image)[0] + result = image_object_detection_auto(test_image) image_object_detection_auto.show_result(test_image, result, 'auto_demo_ret.jpg')