Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9653099master
| @@ -27,11 +27,14 @@ logger = get_logger() | |||||
| Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints) | Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints) | ||||
| class Body2DKeypointsPipeline(Pipeline): | class Body2DKeypointsPipeline(Pipeline): | ||||
| def __init__(self, model: str, human_detector: Pipeline, **kwargs): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| super().__init__(model=model, **kwargs) | super().__init__(model=model, **kwargs) | ||||
| self.keypoint_model = KeypointsDetection(model) | self.keypoint_model = KeypointsDetection(model) | ||||
| self.keypoint_model.eval() | self.keypoint_model.eval() | ||||
| self.human_detector = human_detector | |||||
| self.human_detect_model_id = 'damo/cv_resnet18_human-detection' | |||||
| self.human_detector = pipeline( | |||||
| Tasks.human_detection, model=self.human_detect_model_id) | |||||
| def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: | def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: | ||||
| output = self.human_detector(input) | output = self.human_detector(input) | ||||
| @@ -71,7 +71,6 @@ class Body2DKeypointsTest(unittest.TestCase): | |||||
| def setUp(self) -> None: | def setUp(self) -> None: | ||||
| self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image' | self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image' | ||||
| self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' | self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg' | ||||
| self.human_detect_model_id = 'damo/cv_resnet18_human-detection' | |||||
| def pipeline_inference(self, pipeline: Pipeline): | def pipeline_inference(self, pipeline: Pipeline): | ||||
| output = pipeline(self.test_image) | output = pipeline(self.test_image) | ||||
| @@ -87,12 +86,8 @@ class Body2DKeypointsTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_modelhub(self): | def test_run_modelhub(self): | ||||
| human_detector = pipeline( | |||||
| Tasks.human_detection, model=self.human_detect_model_id) | |||||
| body_2d_keypoints = pipeline( | body_2d_keypoints = pipeline( | ||||
| Tasks.body_2d_keypoints, | |||||
| human_detector=human_detector, | |||||
| model=self.model_id) | |||||
| Tasks.body_2d_keypoints, model=self.model_id) | |||||
| self.pipeline_inference(body_2d_keypoints) | self.pipeline_inference(body_2d_keypoints) | ||||