|
|
|
@@ -27,11 +27,14 @@ logger = get_logger() |
|
|
|
Tasks.body_2d_keypoints, module_name=Pipelines.body_2d_keypoints) |
|
|
|
class Body2DKeypointsPipeline(Pipeline): |
|
|
|
|
|
|
|
def __init__(self, model: str, human_detector: Pipeline, **kwargs): |
|
|
|
def __init__(self, model: str, **kwargs): |
|
|
|
super().__init__(model=model, **kwargs) |
|
|
|
self.keypoint_model = KeypointsDetection(model) |
|
|
|
self.keypoint_model.eval() |
|
|
|
self.human_detector = human_detector |
|
|
|
|
|
|
|
self.human_detect_model_id = 'damo/cv_resnet18_human-detection' |
|
|
|
self.human_detector = pipeline( |
|
|
|
Tasks.human_detection, model=self.human_detect_model_id) |
|
|
|
|
|
|
|
def preprocess(self, input: Input) -> Dict[Tensor, Union[str, np.ndarray]]: |
|
|
|
output = self.human_detector(input) |
|
|
|
|