From ac1ba2a0e05cd2fb48ddcbe383b7bea1a418305c Mon Sep 17 00:00:00 2001
From: "shouzhou.bx"
Date: Tue, 9 Aug 2022 18:07:09 +0800
Subject: [PATCH] [to #42322933] bugfix: add PIL image type support and
 model.to(device) for the body_2d_keypoints pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9684583
---
 modelscope/outputs.py                      | 18 ++++++-------
 .../cv/body_2d_keypoints_pipeline.py       | 25 ++++++++++---------
 tests/pipelines/test_body_2d_keypoints.py  | 15 ++++++++---
 3 files changed, 33 insertions(+), 25 deletions(-)

diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index a288a4c3..47799e04 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -161,19 +161,19 @@ TASK_OUTPUTS = {
     # human body keypoints detection result for single sample
     # {
     #   "poses": [
-    #     [x, y],
-    #     [x, y],
-    #     [x, y]
+    #     [[x, y]*15],
+    #     [[x, y]*15],
+    #     [[x, y]*15]
     #   ]
     #   "scores": [
-    #     [score],
-    #     [score],
-    #     [score],
+    #     [[score]*15],
+    #     [[score]*15],
+    #     [[score]*15]
     #   ]
     #   "boxes": [
-    #     [x1, y1, x2, y2],
-    #     [x1, y1, x2, y2],
-    #     [x1, y1, x2, y2],
+    #     [[x1, y1], [x2, y2]],
+    #     [[x1, y1], [x2, y2]],
+    #     [[x1, y1], [x2, y2]],
    #   ]
     # }
     Tasks.body_2d_keypoints:
diff --git a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py
index 887b53c7..f9ae4b2c 100644
--- a/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py
+++ b/modelscope/pipelines/cv/body_2d_keypoints_pipeline.py
@@ -16,7 +16,7 @@
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines import pipeline
 from modelscope.pipelines.base import Input, Model, Pipeline, Tensor
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
@@ -29,8 +29,9 @@ class Body2DKeypointsPipeline(Pipeline):
 
     def __init__(self, model: str, **kwargs):
         super().__init__(model=model, **kwargs)
-        self.keypoint_model = KeypointsDetection(model)
-        self.keypoint_model.eval()
+        device = torch.device(
+            f'cuda:{0}' if torch.cuda.is_available() else 'cpu')
+        self.keypoint_model = KeypointsDetection(model, device)
 
         self.human_detect_model_id = 'damo/cv_resnet18_human-detection'
         self.human_detector = pipeline(
@@ -39,12 +40,8 @@
     def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]:
         output = self.human_detector(input)
 
-        if isinstance(input, str):
-            image = cv2.imread(input, -1)[:, :, 0:3]
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                image = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
-            image = image[:, :, 0:3]
+        image = LoadImage.convert_to_ndarray(input)
+        image = image[:, :, [2, 1, 0]]  # rgb2bgr
 
         return {'image': image, 'output': output}
@@ -88,14 +85,18 @@ class KeypointsDetection():
 
-    def __init__(self, model: str, **kwargs):
+    def __init__(self, model: str, device: str, **kwargs):
         self.model = model
+        self.device = device
         cfg = cfg_128x128_15
         self.key_points_model = PoseHighResolutionNetV2(cfg)
         pretrained_state_dict = torch.load(
-            osp.join(self.model, ModelFile.TORCH_MODEL_FILE))
+            osp.join(self.model, ModelFile.TORCH_MODEL_FILE),
+            map_location=device)
         self.key_points_model.load_state_dict(
             pretrained_state_dict, strict=False)
+        self.key_points_model = self.key_points_model.to(device)
+        self.key_points_model.eval()
         self.input_size = cfg['MODEL']['IMAGE_SIZE']
         self.lst_parent_ids = cfg['DATASET']['PARENT_IDS']
@@ -111,7 +112,7 @@ class KeypointsDetection():
 
     def forward(self, input: Tensor) -> Tensor:
         with torch.no_grad():
-            return self.key_points_model.forward(input)
+            return self.key_points_model.forward(input.to(self.device))
 
     def get_pts(self, heatmaps):
         [pts_num, height, width] = heatmaps.shape
diff --git a/tests/pipelines/test_body_2d_keypoints.py b/tests/pipelines/test_body_2d_keypoints.py
index e22925a6..eca5e961 100644
--- a/tests/pipelines/test_body_2d_keypoints.py
+++ b/tests/pipelines/test_body_2d_keypoints.py
@@ -3,6 +3,7 @@ import unittest
 
 import cv2
 import numpy as np
+from PIL import Image
 
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines import pipeline
@@ -68,8 +69,8 @@ class Body2DKeypointsTest(unittest.TestCase):
         self.model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image'
         self.test_image = 'data/test/images/keypoints_detect/000000438862.jpg'
 
-    def pipeline_inference(self, pipeline: Pipeline):
-        output = pipeline(self.test_image)
+    def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
+        output = pipeline(pipeline_input)
         poses = np.array(output[OutputKeys.POSES])
         scores = np.array(output[OutputKeys.SCORES])
         boxes = np.array(output[OutputKeys.BOXES])
@@ -80,11 +81,17 @@ class Body2DKeypointsTest(unittest.TestCase):
             draw_joints(image, np.array(poses[i]), np.array(scores[i]))
         cv2.imwrite('pose_keypoint.jpg', image)
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_modelhub_with_image_file(self):
+        body_2d_keypoints = pipeline(
+            Tasks.body_2d_keypoints, model=self.model_id)
+        self.pipeline_inference(body_2d_keypoints, self.test_image)
+
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_modelhub(self):
+    def test_run_modelhub_with_image_input(self):
         body_2d_keypoints = pipeline(
             Tasks.body_2d_keypoints, model=self.model_id)
-        self.pipeline_inference(body_2d_keypoints)
+        self.pipeline_inference(body_2d_keypoints, Image.open(self.test_image))
 
 
 if __name__ == '__main__':
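
Reviewer note: a minimal usage sketch of the new PIL input path, mirroring the updated tests. The model id, test image path, and output keys are taken from this patch; it assumes a ModelScope installation with the patch applied.

from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

model_id = 'damo/cv_hrnetv2w32_body-2d-keypoints_image'
test_image = 'data/test/images/keypoints_detect/000000438862.jpg'

# The pipeline now places the keypoint model on cuda:0 when available,
# falling back to CPU otherwise.
body_2d_keypoints = pipeline(Tasks.body_2d_keypoints, model=model_id)

# Both input types route through LoadImage.convert_to_ndarray(), so a
# file path and a PIL image should produce the same result.
result_from_path = body_2d_keypoints(test_image)
result_from_pil = body_2d_keypoints(Image.open(test_image))

# Per TASK_OUTPUTS: 15 [x, y] keypoints and 15 scores per detected
# person, plus one [[x1, y1], [x2, y2]] bounding box per person.
print(result_from_pil[OutputKeys.POSES])
print(result_from_pil[OutputKeys.SCORES])
print(result_from_pil[OutputKeys.BOXES])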