From 00b448a2feb2982973c6716ba793d81fb5f1ef59 Mon Sep 17 00:00:00 2001 From: "hanyuan.chy" Date: Sat, 27 Aug 2022 14:11:30 +0800 Subject: [PATCH] [to #42322933] support 3d body keypoints Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9862567 --- data/test/videos/Walking.54138969.mp4 | 3 + modelscope/metainfo.py | 2 + modelscope/models/cv/__init__.py | 11 +- .../models/cv/body_3d_keypoints/__init__.py | 23 ++ .../cv/body_3d_keypoints/body_3d_pose.py | 246 ++++++++++++++++++ .../canonical_pose_modules.py | 233 +++++++++++++++++ modelscope/outputs.py | 10 + modelscope/pipelines/builder.py | 2 + modelscope/pipelines/cv/__init__.py | 2 + .../cv/body_3d_keypoints_pipeline.py | 213 +++++++++++++++ modelscope/utils/constant.py | 1 + tests/pipelines/test_body_3d_keypoints.py | 49 ++++ 12 files changed, 792 insertions(+), 3 deletions(-) create mode 100644 data/test/videos/Walking.54138969.mp4 create mode 100644 modelscope/models/cv/body_3d_keypoints/__init__.py create mode 100644 modelscope/models/cv/body_3d_keypoints/body_3d_pose.py create mode 100644 modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py create mode 100644 modelscope/pipelines/cv/body_3d_keypoints_pipeline.py create mode 100644 tests/pipelines/test_body_3d_keypoints.py diff --git a/data/test/videos/Walking.54138969.mp4 b/data/test/videos/Walking.54138969.mp4 new file mode 100644 index 00000000..1716695f --- /dev/null +++ b/data/test/videos/Walking.54138969.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7b8f50a0537bfe7e082c5ad91b2b7ece61a0adbeb7489988e553909276bf920c +size 44217644 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 153ca9b4..d9e53ca7 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -20,6 +20,7 @@ class Models(object): gpen = 'gpen' product_retrieval_embedding = 'product-retrieval-embedding' body_2d_keypoints = 'body-2d-keypoints' + body_3d_keypoints = 'body-3d-keypoints' crowd_counting = 'HRNetCrowdCounting' panoptic_segmentation = 'swinL-panoptic-segmentation' image_reid_person = 'passvitb' @@ -95,6 +96,7 @@ class Pipelines(object): general_recognition = 'resnet101-general-recognition' cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' + body_3d_keypoints = 'canonical_body-3d-keypoints_video' human_detection = 'resnet18-human-detection' object_detection = 'vit-object-detection' easycv_detection = 'easycv-detection' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 74451c31..10040637 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -1,11 +1,16 @@ # Copyright (c) Alibaba, Inc. and its affiliates. + +# yapf: disable from . 
import (action_recognition, animal_recognition, body_2d_keypoints, - cartoon, cmdssl_video_embedding, crowd_counting, face_detection, - face_generation, image_classification, image_color_enhance, - image_colorization, image_denoise, image_instance_segmentation, + body_3d_keypoints, cartoon, cmdssl_video_embedding, + crowd_counting, face_detection, face_generation, + image_classification, image_color_enhance, image_colorization, + image_denoise, image_instance_segmentation, image_panoptic_segmentation, image_portrait_enhancement, image_reid_person, image_semantic_segmentation, image_to_image_generation, image_to_image_translation, object_detection, product_retrieval_embedding, realtime_object_detection, salient_detection, super_resolution, video_single_object_tracking, video_summarization, virual_tryon) + +# yapf: enable diff --git a/modelscope/models/cv/body_3d_keypoints/__init__.py b/modelscope/models/cv/body_3d_keypoints/__init__.py new file mode 100644 index 00000000..4bb83936 --- /dev/null +++ b/modelscope/models/cv/body_3d_keypoints/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .body_3d_pose import BodyKeypointsDetection3D + +else: + _import_structure = { + 'body_3d_pose': ['BodyKeypointsDetection3D'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py b/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py new file mode 100644 index 00000000..87cd4962 --- /dev/null +++ b/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py @@ -0,0 +1,246 @@ +import logging +import os.path as osp +from typing import Any, Dict, List, Union + +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.models.cv.body_3d_keypoints.canonical_pose_modules import ( + TemporalModel, TransCan3Dkeys) +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['BodyKeypointsDetection3D'] + + +class KeypointsTypes(object): + POSES_CAMERA = 'poses_camera' + POSES_TRAJ = 'poses_traj' + + +@MODELS.register_module( + Tasks.body_3d_keypoints, module_name=Models.body_3d_keypoints) +class BodyKeypointsDetection3D(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + + super().__init__(model_dir, *args, **kwargs) + + self.model_dir = model_dir + model_path = osp.join(self.model_dir, ModelFile.TORCH_MODEL_FILE) + cfg_path = osp.join(self.model_dir, ModelFile.CONFIGURATION) + self.cfg = Config.from_file(cfg_path) + self._create_model() + + if not osp.exists(model_path): + raise IOError(f'{model_path} is not exists.') + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.pretrained_state_dict = torch.load( + model_path, map_location=self._device) + + self.load_pretrained() + self.to_device(self._device) + self.eval() + + def _create_model(self): + self.model_pos = TemporalModel( + self.cfg.model.MODEL.IN_NUM_JOINTS, + self.cfg.model.MODEL.IN_2D_FEATURE, + self.cfg.model.MODEL.OUT_NUM_JOINTS, + 
filter_widths=self.cfg.model.MODEL.FILTER_WIDTHS,
+            causal=self.cfg.model.MODEL.CAUSAL,
+            dropout=self.cfg.model.MODEL.DROPOUT,
+            channels=self.cfg.model.MODEL.CHANNELS,
+            dense=self.cfg.model.MODEL.DENSE)
+
+        receptive_field = self.model_pos.receptive_field()
+        self.pad = (receptive_field - 1) // 2
+        if self.cfg.model.MODEL.CAUSAL:
+            self.causal_shift = self.pad
+        else:
+            self.causal_shift = 0
+
+        self.model_traj = TransCan3Dkeys(
+            in_channels=self.cfg.model.MODEL.IN_NUM_JOINTS
+            * self.cfg.model.MODEL.IN_2D_FEATURE,
+            num_features=1024,
+            out_channels=self.cfg.model.MODEL.OUT_3D_FEATURE,
+            num_blocks=4,
+            time_window=receptive_field)
+
+    def eval(self):
+        self.model_pos.eval()
+        self.model_traj.eval()
+
+    def train(self):
+        self.model_pos.train()
+        self.model_traj.train()
+
+    def to_device(self, device):
+        self.model_pos = self.model_pos.to(device)
+        self.model_traj = self.model_traj.to(device)
+
+    def load_pretrained(self):
+        if 'model_pos' in self.pretrained_state_dict:
+            self.model_pos.load_state_dict(
+                self.pretrained_state_dict['model_pos'], strict=False)
+        else:
+            logger.error(
+                'Failed to load model_pos: key "model_pos" not found in pretrained_state_dict.'
+            )
+
+        if 'model_traj' in self.pretrained_state_dict:
+            self.model_traj.load_state_dict(
+                self.pretrained_state_dict['model_traj'], strict=False)
+        else:
+            logger.error(
+                'Failed to load model_traj: key "model_traj" not found in pretrained_state_dict.'
+            )
+        logger.info('Loaded pretrained model.')
+
+    def preprocess(self, input: torch.Tensor) -> Dict[str, Any]:
+        """Preprocess the 2D input joints.
+
+        Args:
+            input (torch.Tensor): [NUM_FRAME, NUM_JOINTS, 2], input 2D human body keypoints.
+
+        Returns:
+            Dict[str, Any]: canonical 2D points and root-relative joints.
+        """
+        if 'cuda' == input.device.type:
+            input = input.data.cpu().numpy()
+        elif 'cpu' == input.device.type:
+            input = input.data.numpy()
+        pose2d = input
+
+        pose2d_canonical = self.canonicalize_2Ds(
+            pose2d, self.cfg.model.INPUT.FOCAL_LENGTH,
+            self.cfg.model.INPUT.CENTER)
+        pose2d_normalized = self.normalize_screen_coordinates(
+            pose2d, self.cfg.model.INPUT.RES_W, self.cfg.model.INPUT.RES_H)
+        pose2d_rr = pose2d_normalized
+        pose2d_rr[:, 1:] -= pose2d_rr[:, :1]
+
+        # expand [NUM_FRAME, NUM_JOINTS, 2] to [1, NUM_FRAME, NUM_JOINTS, 2]
+        pose2d_rr = np.expand_dims(
+            np.pad(
+                pose2d_rr,
+                ((self.pad + self.causal_shift, self.pad - self.causal_shift),
+                 (0, 0), (0, 0)), 'edge'),
+            axis=0)
+        pose2d_canonical = np.expand_dims(
+            np.pad(
+                pose2d_canonical,
+                ((self.pad + self.causal_shift, self.pad - self.causal_shift),
+                 (0, 0), (0, 0)), 'edge'),
+            axis=0)
+        pose2d_rr = torch.from_numpy(pose2d_rr.astype(np.float32))
+        pose2d_canonical = torch.from_numpy(
+            pose2d_canonical.astype(np.float32))
+
+        inputs_2d = pose2d_rr.clone()
+        if torch.cuda.is_available():
+            inputs_2d = inputs_2d.cuda(non_blocking=True)
+
+        # Positional model
+        if self.cfg.model.MODEL.USE_2D_OFFSETS:
+            inputs_2d[:, :, 0] = 0
+        else:
+            inputs_2d[:, :, 1:] += inputs_2d[:, :, :1]
+
+        return {
+            'inputs_2d': inputs_2d,
+            'pose2d_rr': pose2d_rr,
+            'pose2d_canonical': pose2d_canonical
+        }
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        """3D human pose estimation.
+
+        Args:
+            input (Dict):
+                inputs_2d: [1, NUM_FRAME, NUM_JOINTS, 2]
+                pose2d_rr: [1, NUM_FRAME, NUM_JOINTS, 2]
+                pose2d_canonical: [1, NUM_FRAME, NUM_JOINTS, 2]
+                NUM_FRAME = max(receptive_field + video_frame_number, video_frame_number)
+
+        Returns:
+            Dict[str, Any]:
+                "camera_pose": Tensor, [1, NUM_FRAME, OUT_NUM_JOINTS, OUT_3D_FEATURE_DIM],
+                    3D human pose keypoints in the camera frame.
+                "camera_traj": Tensor, [1, NUM_FRAME, 1, 3],
+                    root keypoint coordinates in the camera frame.
+        """
+        inputs_2d = input['inputs_2d']
+        pose2d_rr = input['pose2d_rr']
+        pose2d_canonical = input['pose2d_canonical']
+        with torch.no_grad():
+            # predict 3D pose keypoints
+            predicted_3d_pos = self.model_pos(inputs_2d)
+
+            # predict global trajectory
+            b1, w1, n1, d1 = inputs_2d.shape
+
+            input_pose2d_abs = self.get_abs_2d_pts(w1, pose2d_rr,
+                                                   pose2d_canonical)
+            b1, w1, n1, d1 = input_pose2d_abs.size()
+            b2, w2, n2, d2 = predicted_3d_pos.size()
+
+            if torch.cuda.is_available():
+                input_pose2d_abs = input_pose2d_abs.cuda(non_blocking=True)
+
+            predicted_3d_traj = self.model_traj(
+                input_pose2d_abs.view(b1, w1, n1 * d1),
+                predicted_3d_pos.view(b2 * w2, n2 * d2)).view(b2, w2, -1, 3)
+
+        predict_dict = {
+            KeypointsTypes.POSES_CAMERA: predicted_3d_pos,
+            KeypointsTypes.POSES_TRAJ: predicted_3d_traj
+        }
+
+        return predict_dict
+
+    def get_abs_2d_pts(self, input_video_frame_num, pose2d_rr,
+                       pose2d_canonical):
+        pad = self.pad
+        w = input_video_frame_num - pad * 2
+
+        lst_pose2d_rr = []
+        lst_pose2d_canonical = []
+        for i in range(pad, w + pad):
+            lst_pose2d_rr.append(pose2d_rr[:, i - pad:i + pad + 1])
+            lst_pose2d_canonical.append(pose2d_canonical[:,
+                                                         i - pad:i + pad + 1])
+
+        input_pose2d_rr = torch.concat(lst_pose2d_rr, axis=0)
+        input_pose2d_canonical = torch.concat(lst_pose2d_canonical, axis=0)
+
+        if self.cfg.model.MODEL.USE_CANONICAL_COORDS:
+            input_pose2d_abs = input_pose2d_canonical.clone()
+        else:
+            input_pose2d_abs = input_pose2d_rr.clone()
+            input_pose2d_abs[:, :, 1:] += input_pose2d_abs[:, :, :1]
+
+        return input_pose2d_abs
+
+    def canonicalize_2Ds(self, pos2d, f, c):
+        cs = np.array([c[0], c[1]]).reshape(1, 1, 2)
+        fs = np.array([f[0], f[1]]).reshape(1, 1, 2)
+        canonical_2Ds = (pos2d - cs) / fs
+        return canonical_2Ds
+
+    def normalize_screen_coordinates(self, X, w, h):
+        assert X.shape[-1] == 2
+
+        # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
+        return X / w * 2 - [1, h / w]
diff --git a/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py b/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py
new file mode 100644
index 00000000..b3eac2e5
--- /dev/null
+++ b/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py
@@ -0,0 +1,233 @@
+# The implementation is based on VideoPose3D, available at https://github.com/facebookresearch/VideoPose3D
+import torch
+import torch.nn as nn
+
+
+class TemporalModelBase(nn.Module):
+    """
+    Do not instantiate this class.
+ """ + + def __init__(self, num_joints_in, in_features, num_joints_out, + filter_widths, causal, dropout, channels): + super().__init__() + + # Validate input + for fw in filter_widths: + assert fw % 2 != 0, 'Only odd filter widths are supported' + + self.num_joints_in = num_joints_in + self.in_features = in_features + self.num_joints_out = num_joints_out + self.filter_widths = filter_widths + + self.drop = nn.Dropout(dropout) + self.relu = nn.ReLU(inplace=True) + + self.pad = [filter_widths[0] // 2] + self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1) + self.shrink = nn.Conv1d(channels, num_joints_out * 3, 1) + + def set_bn_momentum(self, momentum): + self.expand_bn.momentum = momentum + for bn in self.layers_bn: + bn.momentum = momentum + + def receptive_field(self): + """ + Return the total receptive field of this model as # of frames. + """ + frames = 0 + for f in self.pad: + frames += f + return 1 + 2 * frames + + def total_causal_shift(self): + """ + Return the asymmetric offset for sequence padding. + The returned value is typically 0 if causal convolutions are disabled, + otherwise it is half the receptive field. + """ + frames = self.causal_shift[0] + next_dilation = self.filter_widths[0] + for i in range(1, len(self.filter_widths)): + frames += self.causal_shift[i] * next_dilation + next_dilation *= self.filter_widths[i] + return frames + + def forward(self, x): + assert len(x.shape) == 4 + assert x.shape[-2] == self.num_joints_in + assert x.shape[-1] == self.in_features + + sz = x.shape[:3] + x = x.view(x.shape[0], x.shape[1], -1) + x = x.permute(0, 2, 1) + + x = self._forward_blocks(x) + + x = x.permute(0, 2, 1) + x = x.view(sz[0], -1, self.num_joints_out, 3) + + return x + + +class TemporalModel(TemporalModelBase): + """ + Reference 3D pose estimation model with temporal convolutions. + This implementation can be used for all use-cases. + """ + + def __init__(self, + num_joints_in, + in_features, + num_joints_out, + filter_widths, + causal=False, + dropout=0.25, + channels=1024, + dense=False): + """ + Initialize this model. + + Arguments: + num_joints_in -- number of input joints (e.g. 
17 for Human3.6M) + in_features -- number of input features for each joint (typically 2 for 2D input) + num_joints_out -- number of output joints (can be different than input) + filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field + causal -- use causal convolutions instead of symmetric convolutions (for real-time applications) + dropout -- dropout probability + channels -- number of convolution channels + dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment) + """ + super().__init__(num_joints_in, in_features, num_joints_out, + filter_widths, causal, dropout, channels) + + self.expand_conv = nn.Conv1d( + num_joints_in * in_features, + channels, + filter_widths[0], + bias=False) + + layers_conv = [] + layers_bn = [] + + self.causal_shift = [(filter_widths[0]) // 2 if causal else 0] + next_dilation = filter_widths[0] + for i in range(1, len(filter_widths)): + self.pad.append((filter_widths[i] - 1) * next_dilation // 2) + self.causal_shift.append((filter_widths[i] // 2 + * next_dilation) if causal else 0) + + layers_conv.append( + nn.Conv1d( + channels, + channels, + filter_widths[i] if not dense else (2 * self.pad[-1] + 1), + dilation=next_dilation if not dense else 1, + bias=False)) + layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) + layers_conv.append( + nn.Conv1d(channels, channels, 1, dilation=1, bias=False)) + layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1)) + + next_dilation *= filter_widths[i] + + self.layers_conv = nn.ModuleList(layers_conv) + self.layers_bn = nn.ModuleList(layers_bn) + + def _forward_blocks(self, x): + x = self.drop(self.relu(self.expand_bn(self.expand_conv(x)))) + for i in range(len(self.pad) - 1): + pad = self.pad[i + 1] + shift = self.causal_shift[i + 1] + res = x[:, :, pad + shift:x.shape[2] - pad + shift] + x = self.drop( + self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x)))) + x = res + self.drop( + self.relu(self.layers_bn[2 * i + 1]( + self.layers_conv[2 * i + 1](x)))) + + x = self.shrink(x) + return x + + +# regression of the trajectory +class TransCan3Dkeys(nn.Module): + + def __init__(self, + in_channels=74, + num_features=256, + out_channels=44, + time_window=10, + num_blocks=2): + super().__init__() + self.in_channels = in_channels + self.num_features = num_features + self.out_channels = out_channels + self.num_blocks = num_blocks + self.time_window = time_window + + self.expand_bn = nn.BatchNorm1d(self.num_features, momentum=0.1) + self.conv1 = nn.Sequential( + nn.ReplicationPad1d(1), + nn.Conv1d( + self.in_channels, self.num_features, kernel_size=3, + bias=False), self.expand_bn, nn.ReLU(inplace=True), + nn.Dropout(p=0.25)) + self._make_blocks() + self.pad = nn.ReplicationPad1d(4) + self.relu = nn.ReLU(inplace=True) + self.drop = nn.Dropout(p=0.25) + self.reduce = nn.Conv1d( + self.num_features, self.num_features, kernel_size=self.time_window) + self.embedding_3d_1 = nn.Linear(in_channels // 2 * 3, 500) + self.embedding_3d_2 = nn.Linear(500, 500) + self.LReLU1 = nn.LeakyReLU() + self.LReLU2 = nn.LeakyReLU() + self.LReLU3 = nn.LeakyReLU() + self.out1 = nn.Linear(self.num_features + 500, self.num_features) + self.out2 = nn.Linear(self.num_features, self.out_channels) + + def _make_blocks(self): + layers_conv = [] + layers_bn = [] + for i in range(self.num_blocks): + layers_conv.append( + nn.Conv1d( + self.num_features, + self.num_features, + kernel_size=5, + bias=False, + dilation=2)) + layers_bn.append(nn.BatchNorm1d(self.num_features)) 
+ self.layers_conv = nn.ModuleList(layers_conv) + self.layers_bn = nn.ModuleList(layers_bn) + + def set_bn_momentum(self, momentum): + self.expand_bn.momentum = momentum + for bn in self.layers_bn: + bn.momentum = momentum + + def forward(self, p2ds, p3d): + """ + Args: + x - (B x T x J x C) + """ + B, T, C = p2ds.shape + x = p2ds.permute((0, 2, 1)) + x = self.conv1(x) + for i in range(self.num_blocks): + pre = x + x = self.pad(x) + x = self.layers_conv[i](x) + x = self.layers_bn[i](x) + x = self.drop(self.relu(x)) + x = pre + x + x_2d = self.relu(self.reduce(x)) + x_2d = x_2d.view(B, -1) + x_3d = self.LReLU1(self.embedding_3d_1(p3d)) + x = torch.cat((x_2d, x_3d), 1) + x = self.LReLU3(self.out1(x)) + x = self.out2(x) + return x diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 2edd76a2..622d9034 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -189,6 +189,16 @@ TASK_OUTPUTS = { Tasks.body_2d_keypoints: [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES], + # 3D human body keypoints detection result for single sample + # { + # "poses": [ + # [[x, y, z]*17], + # [[x, y, z]*17], + # [[x, y, z]*17] + # ] + # } + Tasks.body_3d_keypoints: [OutputKeys.POSES], + # video single object tracking result for single video # { # "boxes": [ diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index fa6705a7..f8f679e6 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -89,6 +89,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_diffusion_text-to-image-synthesis_tiny'), Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints, 'damo/cv_hrnetv2w32_body-2d-keypoints_image'), + Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints, + 'damo/cv_canonical_body-3d-keypoints_video'), Tasks.face_detection: (Pipelines.face_detection, 'damo/cv_resnet_facedetection_scrfd10gkps'), Tasks.face_recognition: (Pipelines.face_recognition, diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 2c062226..640ffd4c 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from .action_recognition_pipeline import ActionRecognitionPipeline from .animal_recognition_pipeline import AnimalRecognitionPipeline from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline + from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline from .crowd_counting_pipeline import CrowdCountingPipeline from .image_detection_pipeline import ImageDetectionPipeline @@ -46,6 +47,7 @@ else: 'action_recognition_pipeline': ['ActionRecognitionPipeline'], 'animal_recognition_pipeline': ['AnimalRecognitionPipeline'], 'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'], + 'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'], 'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'], 'crowd_counting_pipeline': ['CrowdCountingPipeline'], 'image_detection_pipeline': ['ImageDetectionPipeline'], diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py new file mode 100644 index 00000000..e9e4e9e8 --- /dev/null +++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py @@ -0,0 +1,213 @@ +import os +import os.path as osp +from typing import Any, Dict, List, Union + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.body_3d_keypoints.body_3d_pose 
import ( + BodyKeypointsDetection3D, KeypointsTypes) +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Input, Model, Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def convert_2_h36m(joints, joints_nbr=15): + lst_mappings = [[0, 8], [1, 7], [2, 12], [3, 13], [4, 14], [5, 9], [6, 10], + [7, 11], [8, 1], [9, 2], [10, 3], [11, 4], [12, 5], + [13, 6], [14, 0]] + nbr, dim = joints.shape + h36m_joints = np.zeros((nbr, dim)) + for mapping in lst_mappings: + h36m_joints[mapping[1]] = joints[mapping[0]] + + if joints_nbr == 17: + lst_mappings_17 = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], + [5, 5], [6, 6], [7, 8], [8, 10], [9, 11], + [10, 12], [11, 13], [12, 14], [13, 15], + [14, 16]]) + h36m_joints_17 = np.zeros((17, 2)) + h36m_joints_17[lst_mappings_17[:, 1]] = h36m_joints[lst_mappings_17[:, + 0]] + h36m_joints_17[7] = (h36m_joints_17[0] + h36m_joints_17[8]) * 0.5 + h36m_joints_17[9] = (h36m_joints_17[8] + h36m_joints_17[10]) * 0.5 + h36m_joints = h36m_joints_17 + + return h36m_joints + + +def smooth_pts(cur_pts, pre_pts, bbox, smooth_x=15.0, smooth_y=15.0): + if pre_pts is None: + return cur_pts + + w, h = bbox[1] - bbox[0] + if w == 0 or h == 0: + return cur_pts + + size_pre = len(pre_pts) + size_cur = len(cur_pts) + if (size_pre == 0 or size_cur == 0): + return cur_pts + + factor_x = -(smooth_x / w) + factor_y = -(smooth_y / w) + + for i in range(size_cur): + w_x = np.exp(factor_x * np.abs(cur_pts[i][0] - pre_pts[i][0])) + w_y = np.exp(factor_y * np.abs(cur_pts[i][1] - pre_pts[i][1])) + cur_pts[i][0] = (1.0 - w_x) * cur_pts[i][0] + w_x * pre_pts[i][0] + cur_pts[i][1] = (1.0 - w_y) * cur_pts[i][1] + w_y * pre_pts[i][1] + return cur_pts + + +def smoothing(lst_kps, lst_bboxes, smooth_x=15.0, smooth_y=15.0): + assert lst_kps.shape[0] == lst_bboxes.shape[0] + + lst_smoothed_kps = [] + prev_pts = None + for i in range(lst_kps.shape[0]): + smoothed_cur_kps = smooth_pts(lst_kps[i], prev_pts, + lst_bboxes[i][0:-1].reshape(2, 2), + smooth_x, smooth_y) + lst_smoothed_kps.append(smoothed_cur_kps) + prev_pts = smoothed_cur_kps + + return np.array(lst_smoothed_kps) + + +def convert_2_h36m_data(lst_kps, lst_bboxes, joints_nbr=15): + lst_kps = lst_kps.squeeze() + lst_bboxes = lst_bboxes.squeeze() + + assert lst_kps.shape[0] == lst_bboxes.shape[0] + + lst_kps = smoothing(lst_kps, lst_bboxes) + + keypoints = [] + for i in range(lst_kps.shape[0]): + h36m_joints_2d = convert_2_h36m(lst_kps[i], joints_nbr=joints_nbr) + keypoints.append(h36m_joints_2d) + return keypoints + + +@PIPELINES.register_module( + Tasks.body_3d_keypoints, module_name=Pipelines.body_3d_keypoints) +class Body3DKeypointsPipeline(Pipeline): + + def __init__(self, model: Union[str, BodyKeypointsDetection3D], **kwargs): + """Human body 3D pose estimation. + + Args: + model (Union[str, BodyKeypointsDetection3D]): model id on modelscope hub. 
+        """
+        super().__init__(model=model, **kwargs)
+
+        self.keypoint_model_3d = model if isinstance(
+            model, BodyKeypointsDetection3D) else Model.from_pretrained(model)
+        self.keypoint_model_3d.eval()
+
+        # init human body 2D keypoints detection pipeline
+        self.human_body_2d_kps_det_pipeline = 'damo/cv_hrnetv2w32_body-2d-keypoints_image'
+        self.human_body_2d_kps_detector = pipeline(
+            Tasks.body_2d_keypoints,
+            model=self.human_body_2d_kps_det_pipeline,
+            device='gpu' if torch.cuda.is_available() else 'cpu')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        video_frames = self.read_video_frames(input)
+        if 0 == len(video_frames):
+            res = {'success': False, 'msg': 'get video frame failed.'}
+            return res
+
+        all_2d_poses = []
+        all_boxes_with_score = []
+        max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME  # max number of video frames used for 3D joint prediction
+        for i, frame in enumerate(video_frames):
+            kps_2d = self.human_body_2d_kps_detector(frame)
+            box = kps_2d['boxes'][
+                0]  # boxes: [[[x1, y1], [x2, y2]], ...], one box per detected human; [0] keeps the first detection
+            pose = kps_2d['poses'][0]  # 2D keypoints: [15, 2]
+            score = kps_2d['scores'][0]  # detection score of the first person
+            all_2d_poses.append(pose)
+            all_boxes_with_score.append(
+                list(np.array(box).reshape(
+                    (-1))) + [score])  # flatten to a list of shape [5]
+            if (i + 1) >= max_frame:
+                break
+
+        all_2d_poses_np = np.array(all_2d_poses).reshape(
+            (len(all_2d_poses), 15,
+             2))  # 15: number of 2D keypoints, 2: keypoint coordinate (x, y)
+        all_boxes_np = np.array(all_boxes_with_score).reshape(
+            (len(all_boxes_with_score), 5))  # [x1, y1, x2, y2, score]
+
+        kps_2d_h36m_17 = convert_2_h36m_data(
+            all_2d_poses_np,
+            all_boxes_np,
+            joints_nbr=self.keypoint_model_3d.cfg.model.MODEL.IN_NUM_JOINTS)
+        kps_2d_h36m_17 = np.array(kps_2d_h36m_17)
+        res = {'success': True, 'input_2d_pts': kps_2d_h36m_17}
+        return res
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        if not input['success']:
+            res = {'success': False, 'msg': 'preprocess failed.'}
+            return res
+
+        input_2d_pts = input['input_2d_pts']
+        outputs = self.keypoint_model_3d.preprocess(input_2d_pts)
+        outputs = self.keypoint_model_3d.forward(outputs)
+        res = dict({'success': True}, **outputs)
+        return res
+
+    def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
+        res = {OutputKeys.POSES: []}
+
+        if not input['success']:
+            pass
+        else:
+            poses = input[KeypointsTypes.POSES_CAMERA]
+            res = {OutputKeys.POSES: poses.data.cpu().numpy()}
+        return res
+
+    def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]):
+        """Read video frames from a local video file or an opened video stream.
+
+        Args:
+            video_url (str or cv2.VideoCapture): Video path or video stream.
+
+        Raises:
+            Exception: Failed to open the video.
+
+        Returns:
+            List[np.ndarray]: List of video frames.
+        """
+        frames = []
+        if isinstance(video_url, str):
+            cap = cv2.VideoCapture(video_url)
+            if not cap.isOpened():
+                raise Exception(
+                    'modelscope error: %s cannot be decoded by OpenCV.' %
+                    (video_url))
+        else:
+            cap = video_url
+
+        max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
+        frame_idx = 0
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+            frame_idx += 1
+            frames.append(frame)
+            if frame_idx >= max_frame_num:
+                break
+        cap.release()
+        return frames
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 52c08594..2141a012 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -24,6 +24,7 @@ class CVTasks(object):
     human_object_interaction = 'human-object-interaction'
     face_image_generation = 'face-image-generation'
     body_2d_keypoints = 'body-2d-keypoints'
+    body_3d_keypoints = 'body-3d-keypoints'
     general_recognition = 'general-recognition'
     image_classification = 'image-classification'
diff --git a/tests/pipelines/test_body_3d_keypoints.py b/tests/pipelines/test_body_3d_keypoints.py
new file mode 100644
index 00000000..50426414
--- /dev/null
+++ b/tests/pipelines/test_body_3d_keypoints.py
@@ -0,0 +1,49 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import pdb
+import unittest
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class Body3DKeypointsTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_canonical_body-3d-keypoints_video'
+        self.test_video = 'data/test/videos/Walking.54138969.mp4'
+
+    def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
+        output = pipeline(pipeline_input)
+        poses = np.array(output[OutputKeys.POSES])
+        print(f'result 3d points shape {poses.shape}')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub_with_video_file(self):
+        body_3d_keypoints = pipeline(
+            Tasks.body_3d_keypoints, model=self.model_id)
+        self.pipeline_inference(body_3d_keypoints, self.test_video)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub_with_video_stream(self):
+        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
+        cap = cv2.VideoCapture(self.test_video)
+        if not cap.isOpened():
+            raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
+                            % (self.test_video))
+        self.pipeline_inference(body_3d_keypoints, cap)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
+        self.pipeline_inference(body_3d_keypoints, self.test_video)
+
+
+if __name__ == '__main__':
+    unittest.main()
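
Reviewer note (not part of the patch): a minimal usage sketch of the new task, mirroring tests/pipelines/test_body_3d_keypoints.py. It assumes the default model id registered in builder.py ('damo/cv_canonical_body-3d-keypoints_video') and the [1, NUM_FRAME, 17, 3] "poses" layout documented in BodyKeypointsDetection3D.forward and modelscope/outputs.py.

    import numpy as np

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the pipeline with the default 3D body keypoints model added in this patch.
    body_3d_keypoints = pipeline(
        Tasks.body_3d_keypoints,
        model='damo/cv_canonical_body-3d-keypoints_video')

    # Run on the bundled test video; the pipeline internally detects 2D keypoints
    # per frame and lifts the sequence to 3D camera-frame coordinates.
    result = body_3d_keypoints('data/test/videos/Walking.54138969.mp4')

    poses = np.array(result[OutputKeys.POSES])  # expected shape: [1, NUM_FRAME, 17, 3]
    print('poses shape:', poses.shape)
    print('frame 0, joint 0 (x, y, z):', poses[0, 0, 0])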