From 79d602a1dafb0801a4b98ad014a1bf65d7755f44 Mon Sep 17 00:00:00 2001 From: "shouzhou.bx" Date: Fri, 5 Aug 2022 12:33:46 +0800 Subject: [PATCH] =?UTF-8?q?[to=20#42322933]=20bugfix=EF=BC=9Amove=20input?= =?UTF-8?q?=20human=5Fdeteciton=20pipeline=20into=20body=5F2d=5Fdetection?= =?UTF-8?q?=20init=20=20=20=20=20=20=20=20=20Link:=20https://code.alibaba-?= =?UTF-8?q?inc.com/Ali-MaaS/MaaS-lib/codereview/9653099?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- modelscope/pipelines/cv/body_2d_keypoints_pipeline.py | 7 +++++-- tests/pipelines/test_body_2d_keypoints.py | 7 +------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py index f16c48e4..887b53c7 100644 --- a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py +++ b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py @@ -27,11 +27,14 @@ logger = get_logger() Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints) class Body2DKeypointsPipeline(Pipeline): - def __init__(self, model: str, human_detector: Pipeline, **kwargs): + def __init__(self, model: str, **kwargs): super().__init__(model=model, **kwargs) self.keypoint_model = KeypointsDetection(model) self.keypoint_model.eval() - self.human_detector = human_detector + + self.human_detect_model_id = 'damo/cv_resnet18_human-detection' + self.human_detector = pipeline( + Tasks.human_detection, model=self.human_detect_model_id) def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: output = self.human_detector(input) diff --git a/tests/pipelines/test_body_2d_keypoints.py b/tests/pipelines/test_body_2d_keypoints.py index 9b5bcdee..3ff00926 100644 --- a/tests/pipelines/test_body_2d_keypoints.py +++ b/tests/pipelines/test_body_2d_keypoints.py @@ -71,7 +71,6 @@ class Body2DKeypointsTest(unittest.TestCase): def setUp(self) -> None: self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image' self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' - self.human_detect_model_id = 'damo/cv_resnet18_human-detection' def pipeline_inference(self, pipeline: Pipeline): output = pipeline(self.test_image) @@ -87,12 +86,8 @@ class Body2DKeypointsTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_modelhub(self): - human_detector = pipeline( - Tasks.human_detection, model=self.human_detect_model_id) body_2d_keypoints = pipeline( - Tasks.body_2d_keypoints, - human_detector=human_detector, - model=self.model_id) + Tasks.body_2d_keypoints, model=self.model_id) self.pipeline_inference(body_2d_keypoints)